diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py
index 4b64a13fab5..dc6b183c4f5 100644
--- a/src/transformers/generation_logits_process.py
+++ b/src/transformers/generation_logits_process.py
@@ -13,8 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import math
 from abc import ABC
-from typing import Iterable, List
+from typing import Callable, Iterable, List

 import numpy as np
 import torch
@@ -372,3 +373,30 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
         )
         scores = scores.masked_fill(banned_mask, -float("inf"))
         return scores
+
+
+class PrefixConstrainedLogitsProcessor(LogitsProcessor):
+    r"""
+    :class:`transformers.LogitsProcessor` that enforces constrained generation and is useful for prefix-conditioned
+    constrained generation. See `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__ for more
+    information.
+
+    Args:
+        prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`):
+            This function constrains the beam search to allowed tokens only at each step. It takes 2 arguments: the
+            batch ID :obj:`batch_id` and :obj:`input_ids`. It has to return a list with the allowed tokens for the
+            next generation step, conditioned on the batch ID :obj:`batch_id` and the previously generated tokens
+            :obj:`input_ids`.
+    """
+
+    def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
+        self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
+        self._num_beams = num_beams
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        mask = torch.full_like(scores, -math.inf)
+        for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
+            for beam_id, sent in enumerate(beam_sent):
+                mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0
+
+        return scores + mask
diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index 2e3aaa979da..6f99460ca5c 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

 import torch
 from torch.nn import functional as F
@@ -26,6 +26,7 @@ from .generation_logits_process import (
     MinLengthLogitsProcessor,
     NoBadWordsLogitsProcessor,
     NoRepeatNGramLogitsProcessor,
+    PrefixConstrainedLogitsProcessor,
     RepetitionPenaltyLogitsProcessor,
     TemperatureLogitsWarper,
     TopKLogitsWarper,
@@ -258,6 +259,8 @@ class GenerationMixin:
         bad_words_ids: List[List[int]],
         min_length: int,
         eos_token_id: int,
+        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
+        num_beams: int,
     ) -> LogitsProcessorList:
         """
         This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
@@ -285,6 +288,8 @@ class GenerationMixin:
             processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
         if min_length is not None and eos_token_id is not None and min_length > -1:
             processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
+        if prefix_allowed_tokens_fn is not None:
+            processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams))
         return processors

     @torch.no_grad()
@@ -309,6 +314,7 @@ class GenerationMixin:
         num_return_sequences: Optional[int] = None,
         decoder_start_token_id: Optional[int] = None,
         use_cache: Optional[bool] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
         **model_kwargs
     ) -> torch.LongTensor:
         r"""
@@ -375,6 +381,13 @@ class GenerationMixin:
             use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
                 Whether or not the model should use the past last key/values attentions (if applicable to the model) to
                 speed up decoding.
+            prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
+                If provided, this function constrains the beam search to allowed tokens only at each step. If not
+                provided, no constraint is applied. This function takes 2 arguments: the batch ID :obj:`batch_id` and
+                :obj:`input_ids`. It has to return a list with the allowed tokens for the next generation step,
+                conditioned on the batch ID :obj:`batch_id` and the previously generated tokens :obj:`input_ids`. This
+                argument is useful for constrained generation conditioned on the prefix, as described in
+                `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
                 If the model is an Encoder-Decoder model, encoder specific kwargs should not be prefixed and decoder specific
@@ -494,6 +507,8 @@ class GenerationMixin:
             bad_words_ids=bad_words_ids,
             min_length=min_length,
             eos_token_id=eos_token_id,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            num_beams=num_beams,
         )

         if is_greedy_gen_mode:
diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index f32b3f036a8..31de9b3922b 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -15,7 +15,7 @@
 """RAG model implementation."""

 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple

 import torch

@@ -1229,6 +1229,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
         num_return_sequences=None,
         decoder_start_token_id=None,
         n_docs=None,
+        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
         **model_kwargs
     ):
         """
@@ -1302,6 +1303,13 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
             n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`)
                 Number of documents to retrieve and/or number of documents for which to generate an answer.
+            prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
+                If provided, this function constrains the beam search to allowed tokens only at each step. If not
+                provided, no constraint is applied. This function takes 2 arguments: the batch ID :obj:`batch_id` and
+                :obj:`input_ids`. It has to return a list with the allowed tokens for the next generation step,
+                conditioned on the batch ID :obj:`batch_id` and the previously generated tokens :obj:`input_ids`. This
+                argument is useful for constrained generation conditioned on the prefix, as described in
+                `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.

         Return:
             :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
@@ -1395,6 +1403,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
             bad_words_ids=bad_words_ids,
             min_length=min_length,
             eos_token_id=eos_token_id,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            num_beams=num_beams,
         )

         if num_beams == 1:
diff --git a/tests/test_generation_logits_process.py b/tests/test_generation_logits_process.py
index bf3ee067b32..7dd0d055178 100644
--- a/tests/test_generation_logits_process.py
+++ b/tests/test_generation_logits_process.py
@@ -31,6 +31,7 @@ if is_torch_available():
         MinLengthLogitsProcessor,
         NoBadWordsLogitsProcessor,
         NoRepeatNGramLogitsProcessor,
+        PrefixConstrainedLogitsProcessor,
         RepetitionPenaltyLogitsProcessor,
         TemperatureLogitsWarper,
         TopKLogitsWarper,
@@ -281,3 +282,23 @@ class LogitsProcessorTest(unittest.TestCase):

         # input_ids should never be changed
         self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
+
+    def test_prefix_constrained_logits_processor(self):
+        vocab_size = 5
+        batch_size = 2
+
+        input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+
+        def prefix_allowed_tokens_fn(batch_id, inputs_ids):
+            return [[0, 1], [2, 3]][batch_id]
+
+        prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, 1)
+
+        filtered_scores = prefix_constrained_logits_proc(input_ids, scores.clone())
+
+        # batch 1: 1st, 2nd (0, 1) tokens are allowed
+        # batch 2: 3rd, 4th (2, 3) tokens are allowed
+        self.assertListEqual(
+            torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]]
+        )
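Below is a minimal usage sketch of the new `prefix_allowed_tokens_fn` argument (not part of the diff above). The `gpt2` checkpoint, the prompt, and the whitelist-style constraint are illustrative assumptions only; the entity-retrieval setting from the paper would typically compute the allowed tokens by walking a prefix trie of entity names rather than returning a fixed list.

```python
# Illustrative sketch: constrain beam search so that every generated token
# must come from a fixed whitelist of token IDs. The checkpoint name and the
# constraint are assumptions, not part of the PR itself.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Token IDs the model is allowed to emit at every step of every batch element.
allowed_ids = tokenizer(" Paris is the capital of France", add_special_tokens=False).input_ids


def prefix_allowed_tokens_fn(batch_id, input_ids):
    # input_ids holds the tokens generated so far for one beam of `batch_id`;
    # the return value is the list of token IDs allowed at the next step.
    return allowed_ids


prompt = tokenizer("The capital of France is", return_tensors="pt")
output_ids = model.generate(
    prompt.input_ids,
    max_length=12,
    num_beams=5,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Internally, `PrefixConstrainedLogitsProcessor` builds a mask that is `-inf` at every disallowed position and `0` at every allowed one and adds it to the scores, so allowed tokens keep their original logits and the constraint composes cleanly with the other logits processors applied during generation.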