Adding PrefixConstrainedLogitsProcessor (#8529)

* Adding PrefixConstrainedLogitsProcessor

* fixing RAG and style_doc

* fixing black (v20 instead of v19)

* Improving doc in generation_logits_process.py

* Improving docs and typing in generation_utils.py

* docs improvement

* adding test and fixing doc typo

* fixing doc_len

* isort on test

* fixed test

* improve docstring a bit

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Author: Nicola De Cao · 2020-11-18 16:06:25 +00:00 · committed via GitHub
Commit: 2f9d49b389 · Parent: 3bc1540070
4 changed files with 77 additions and 3 deletions

File: generation_logits_process.py

@@ -13,8 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import math
 from abc import ABC
-from typing import Iterable, List
+from typing import Callable, Iterable, List

 import numpy as np
 import torch
@@ -372,3 +373,30 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
             )
         scores = scores.masked_fill(banned_mask, -float("inf"))
         return scores
+
+
+class PrefixConstrainedLogitsProcessor(LogitsProcessor):
+    r"""
+    :class:`transformers.LogitsProcessor` that enforces constrained generation and is useful for prefix-conditioned
+    constrained generation. See `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__ for more
+    information.
+
+    Args:
+        prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`):
+            This function constrains the beam search to allowed tokens only at each step. It takes 2 arguments: the
+            batch ID :obj:`batch_id` and :obj:`input_ids`. It has to return a list with the allowed tokens for the
+            next generation step, conditioned on the previously generated tokens :obj:`input_ids` and the batch ID
+            :obj:`batch_id`.
+    """
+
+    def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
+        self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
+        self._num_beams = num_beams
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        # Start from an all -inf mask, then zero out the allowed tokens for each (batch, beam) entry.
+        mask = torch.full_like(scores, -math.inf)
+        for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
+            for beam_id, sent in enumerate(beam_sent):
+                mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0
+
+        return scores + mask
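
For context, a minimal sketch (not part of the diff) of what the new processor does when applied directly to a batch of logits. It assumes `PrefixConstrainedLogitsProcessor` is importable from the installed `transformers` package; the vocabulary size, token ids, and allowed-token sets are made up for illustration.

```python
import torch

# Depending on the installed version, the class may need to be imported from the
# generation logits-process module rather than the top-level package.
from transformers import PrefixConstrainedLogitsProcessor


def prefix_allowed_tokens_fn(batch_id, input_ids):
    # Allow tokens {0, 1} for the first batch element and {2, 3} for the second,
    # independently of what has been generated so far.
    return [[0, 1], [2, 3]][batch_id]


processor = PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams=1)

input_ids = torch.tensor([[0, 1], [0, 2]])  # shape: (batch_size * num_beams, seq_len)
scores = torch.zeros(2, 5)                  # uniform logits over a 5-token vocabulary

filtered = processor(input_ids, scores)
# Disallowed tokens are pushed to -inf; allowed tokens keep their original logits.
print(torch.isinf(filtered))
# tensor([[False, False,  True,  True,  True],
#         [ True,  True, False, False,  True]])
```

This mirrors the unit test added below: every token not returned by the callable is set to `-inf`, so it can never be chosen by greedy decoding, sampling, or beam search.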

File: generation_utils.py

@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

 import torch
 from torch.nn import functional as F
@@ -26,6 +26,7 @@ from .generation_logits_process import (
     MinLengthLogitsProcessor,
     NoBadWordsLogitsProcessor,
     NoRepeatNGramLogitsProcessor,
+    PrefixConstrainedLogitsProcessor,
     RepetitionPenaltyLogitsProcessor,
     TemperatureLogitsWarper,
     TopKLogitsWarper,
@@ -258,6 +259,8 @@ class GenerationMixin:
         bad_words_ids: List[List[int]],
         min_length: int,
         eos_token_id: int,
+        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
+        num_beams: int,
     ) -> LogitsProcessorList:
         """
         This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
@@ -285,6 +288,8 @@ class GenerationMixin:
             processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
         if min_length is not None and eos_token_id is not None and min_length > -1:
             processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
+        if prefix_allowed_tokens_fn is not None:
+            processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams))
         return processors

     @torch.no_grad()
@@ -309,6 +314,7 @@ class GenerationMixin:
         num_return_sequences: Optional[int] = None,
         decoder_start_token_id: Optional[int] = None,
         use_cache: Optional[bool] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
         **model_kwargs
     ) -> torch.LongTensor:
         r"""
@@ -375,6 +381,13 @@ class GenerationMixin:
             use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
                 Whether or not the model should use the past last key/values attentions (if applicable to the model) to
                 speed up decoding.
+            prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
+                If provided, this function constrains the beam search to allowed tokens only at each step. If not
+                provided, no constraint is applied. The function takes 2 arguments: the batch ID :obj:`batch_id` and
+                :obj:`input_ids`. It has to return a list with the allowed tokens for the next generation step,
+                conditioned on the previously generated tokens :obj:`input_ids` and the batch ID :obj:`batch_id`. This
+                argument is useful for constrained generation conditioned on the prefix, as described in
+                `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
                 model is an Encoder-Decoder model, encoder specific kwargs should not be prefixed and decoder specific
@@ -494,6 +507,8 @@ class GenerationMixin:
             bad_words_ids=bad_words_ids,
             min_length=min_length,
             eos_token_id=eos_token_id,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            num_beams=num_beams,
         )

         if is_greedy_gen_mode:
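
Usage sketch (not part of the diff): calling `generate()` with the new argument to force a seq2seq model to begin its output with a fixed phrase. The model name, prompt, and forced phrase are assumptions for illustration only, and the step arithmetic assumes the decoder starts from a single start token, as T5 does.

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Token ids of the phrase the output must start with.
force_ids = tokenizer("Die Antwort ist", add_special_tokens=False).input_ids


def prefix_allowed_tokens_fn(batch_id, input_ids):
    # While still inside the forced phrase, allow only its next token;
    # afterwards, fall back to the full vocabulary.
    step = input_ids.shape[-1] - 1  # subtract the single decoder start token
    if step < len(force_ids):
        return [force_ids[step]]
    return list(range(len(tokenizer)))


inputs = tokenizer("translate English to German: The answer is 42.", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    num_beams=2,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Because the callable only sees the already generated ids of one beam at a time, any stateful constraint (forced phrase, trie, grammar) has to be derivable from those ids alone.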

File: modeling_rag.py

@@ -15,7 +15,7 @@
 """RAG model implementation."""

 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple

 import torch
@@ -1229,6 +1229,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
        num_return_sequences=None,
        decoder_start_token_id=None,
        n_docs=None,
+        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
        **model_kwargs
    ):
        """
@@ -1302,6 +1303,13 @@ class RagTokenForGeneration(RagPreTrainedModel):
                If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
            n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`)
                Number of documents to retrieve and/or number of documents for which to generate an answer.
+            prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
+                If provided, this function constrains the beam search to allowed tokens only at each step. If not
+                provided, no constraint is applied. The function takes 2 arguments: the batch ID :obj:`batch_id` and
+                :obj:`input_ids`. It has to return a list with the allowed tokens for the next generation step,
+                conditioned on the previously generated tokens :obj:`input_ids` and the batch ID :obj:`batch_id`. This
+                argument is useful for constrained generation conditioned on the prefix, as described in
+                `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.

        Return:
            :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
@@ -1395,6 +1403,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
            bad_words_ids=bad_words_ids,
            min_length=min_length,
            eos_token_id=eos_token_id,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            num_beams=num_beams,
        )

        if num_beams == 1:
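
In the entity-retrieval setting of the linked paper, the callable typically restricts decoding to a closed set of entity names by walking a trie over their token ids. Below is a model-agnostic sketch (not from the diff); the token-id sequences are placeholders for what `tokenizer(entity_name).input_ids` would return.

```python
from typing import Dict, List

import torch


def build_trie(sequences: List[List[int]]) -> Dict[int, dict]:
    # Nested-dict trie over token-id sequences.
    trie: Dict[int, dict] = {}
    for seq in sequences:
        node = trie
        for token_id in seq:
            node = node.setdefault(token_id, {})
    return trie


# Three "entity names", already tokenized (placeholder ids).
entity_token_ids = [[8, 4, 2], [8, 7], [5, 9, 1]]
trie = build_trie(entity_token_ids)


def prefix_allowed_tokens_fn(batch_id: int, input_ids: torch.Tensor) -> List[int]:
    # Skip the decoder start token, then follow the generated tokens down the trie;
    # the children of the current node are the only tokens allowed next.
    node = trie
    for token_id in input_ids[1:].tolist():
        node = node.get(token_id, {})
    return list(node.keys())
```

Returning an empty list masks out every token, so in practice the EOS token id is appended to each entity's sequence (or returned once a leaf is reached) to let generation finish cleanly.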

File: test_generation_logits_process.py

@@ -31,6 +31,7 @@ if is_torch_available():
        MinLengthLogitsProcessor,
        NoBadWordsLogitsProcessor,
        NoRepeatNGramLogitsProcessor,
+        PrefixConstrainedLogitsProcessor,
        RepetitionPenaltyLogitsProcessor,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
@@ -281,3 +282,23 @@ class LogitsProcessorTest(unittest.TestCase):
        # input_ids should never be changed
        self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
+
+    def test_prefix_constrained_logits_processor(self):
+        vocab_size = 5
+        batch_size = 2
+
+        input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+
+        def prefix_allowed_tokens_fn(batch_id, input_ids):
+            return [[0, 1], [2, 3]][batch_id]
+
+        prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, 1)
+
+        filtered_scores = prefix_constrained_logits_proc(input_ids, scores.clone())
+
+        # batch 1: the 1st and 2nd tokens (0, 1) are allowed
+        # batch 2: the 3rd and 4th tokens (2, 3) are allowed
+        self.assertListEqual(
+            torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]]
+        )