Adding PrefixConstrainedLogitsProcessor (#8529)
* Adding PrefixConstrainedLogitsProcessor
* fixing RAG and style_doc
* fixing black (v20 instead of v19)
* Improving doc in generation_logits_process.py
* Improving docs and typing in generation_utils.py
* docs improvement
* adding test and fixing doc typo
* fixing doc_len
* isort on test
* fixed test
* improve docstring a bit

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 3bc1540070
commit 2f9d49b389
generation_logits_process.py
@@ -13,8 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 from abc import ABC
-from typing import Iterable, List
+from typing import Callable, Iterable, List
 
 import numpy as np
 import torch
@@ -372,3 +373,30 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
         )
         scores = scores.masked_fill(banned_mask, -float("inf"))
         return scores
+
+
+class PrefixConstrainedLogitsProcessor(LogitsProcessor):
+    r"""
+    :class:`transformers.LogitsProcessor` that enforces constrained generation and is useful for prefix-conditioned
+    constrained generation. See `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__ for more
+    information.
+
+    Args:
+        prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`):
+            This function constrains the beam search to allowed tokens only at each step. It takes 2 arguments:
+            :obj:`input_ids` and the batch ID :obj:`batch_id`. It has to return a list with the allowed tokens for
+            the next generation step conditioned on the previously generated tokens :obj:`input_ids` and the batch
+            ID :obj:`batch_id`.
+    """
+
+    def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
+        self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
+        self._num_beams = num_beams
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        mask = torch.full_like(scores, -math.inf)
+        for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
+            for beam_id, sent in enumerate(beam_sent):
+                mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0
+
+        return scores + mask
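Taken on its own, the new processor builds an additive mask: every token id the callback returns keeps its score, everything else drops to -inf. A minimal standalone sketch of that behavior, assuming the module path introduced by this commit and a toy callback (neither taken verbatim from the diff):

```python
import torch

from transformers.generation_logits_process import PrefixConstrainedLogitsProcessor

def allow_even_tokens(batch_id, sent):
    # Toy constraint: regardless of the prefix, only token ids 0, 2 and 4 may come next.
    return [0, 2, 4]

# num_beams=1, one sequence of length 2, a 6-token vocabulary of uniform logits
processor = PrefixConstrainedLogitsProcessor(allow_even_tokens, num_beams=1)
input_ids = torch.tensor([[1, 3]])
scores = torch.zeros(1, 6)

masked = processor(input_ids, scores)
print(torch.isinf(masked))  # tensor([[False, True, False, True, False, True]])
```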
generation_utils.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
 
 import torch
 from torch.nn import functional as F
@@ -26,6 +26,7 @@ from .generation_logits_process import (
     MinLengthLogitsProcessor,
     NoBadWordsLogitsProcessor,
     NoRepeatNGramLogitsProcessor,
+    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
@@ -258,6 +259,8 @@ class GenerationMixin:
         bad_words_ids: List[List[int]],
         min_length: int,
         eos_token_id: int,
+        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
+        num_beams: int,
     ) -> LogitsProcessorList:
         """
         This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
@@ -285,6 +288,8 @@ class GenerationMixin:
             processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
         if min_length is not None and eos_token_id is not None and min_length > -1:
             processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
+        if prefix_allowed_tokens_fn is not None:
+            processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams))
         return processors
 
     @torch.no_grad()
@@ -309,6 +314,7 @@ class GenerationMixin:
         num_return_sequences: Optional[int] = None,
         decoder_start_token_id: Optional[int] = None,
         use_cache: Optional[bool] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
         **model_kwargs
     ) -> torch.LongTensor:
         r"""
@@ -375,6 +381,13 @@ class GenerationMixin:
            use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not the model should use the past last key/values attentions (if applicable to the model) to
                speed up decoding.
+           prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
+               If provided, this function constrains the beam search to allowed tokens only at each step. If not
+               provided, no constraint is applied. This function takes 2 arguments: :obj:`input_ids` and the batch ID
+               :obj:`batch_id`. It has to return a list with the allowed tokens for the next generation step
+               conditioned on the previously generated tokens :obj:`input_ids` and the batch ID :obj:`batch_id`. This
+               argument is useful for constrained generation conditioned on the prefix, as described in
+               `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
                model is an Encoder-Decoder model, encoder specific kwargs should not be prefixed and decoder specific
@@ -494,6 +507,8 @@ class GenerationMixin:
             bad_words_ids=bad_words_ids,
             min_length=min_length,
             eos_token_id=eos_token_id,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            num_beams=num_beams,
         )
 
         if is_greedy_gen_mode:
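With the plumbing above, the callback can be handed straight to `generate()`. A hedged sketch of the call pattern follows; the `t5-small` checkpoint, the prompt, and the forced prefix are illustrative assumptions, not taken from the commit:

```python
import torch

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Force generation to start with a fixed phrase, then release the constraint.
forced_prefix = tokenizer("Die Antwort ist", add_special_tokens=False).input_ids

def prefix_allowed_tokens_fn(batch_id, input_ids):
    step = input_ids.shape[-1] - 1  # tokens generated after the decoder start token
    if step < len(forced_prefix):
        return [forced_prefix[step]]
    return list(range(model.config.vocab_size))  # no constraint afterwards

inputs = tokenizer("translate English to German: The answer is yes.", return_tensors="pt")
output = model.generate(
    inputs.input_ids,
    num_beams=2,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```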
modeling_rag.py
@@ -15,7 +15,7 @@
 """RAG model implementation."""
 
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import torch
 
@@ -1229,6 +1229,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
         num_return_sequences=None,
         decoder_start_token_id=None,
         n_docs=None,
+        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
         **model_kwargs
     ):
         """
@@ -1302,6 +1303,13 @@ class RagTokenForGeneration(RagPreTrainedModel):
                If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
            n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`)
                Number of documents to retrieve and/or number of documents for which to generate an answer.
+           prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
+               If provided, this function constrains the beam search to allowed tokens only at each step. If not
+               provided, no constraint is applied. This function takes 2 arguments: :obj:`input_ids` and the batch
+               ID :obj:`batch_id`. It has to return a list with the allowed tokens for the next generation step
+               conditioned on the previously generated tokens :obj:`input_ids` and the batch ID :obj:`batch_id`.
+               This argument is useful for constrained generation conditioned on the prefix, as described in
+               `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
 
        Return:
            :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
@@ -1395,6 +1403,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
             bad_words_ids=bad_words_ids,
             min_length=min_length,
             eos_token_id=eos_token_id,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            num_beams=num_beams,
         )
 
         if num_beams == 1:
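The RAG wiring mirrors generation_utils.py: the callback is simply forwarded to the logits processor list. A rough sketch of passing it to `RagTokenForGeneration.generate()`; the `facebook/rag-token-nq` checkpoint, the dummy index, and the trivial allow-everything callback are assumptions for illustration only:

```python
from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

input_dict = tokenizer.prepare_seq2seq_batch(
    ["who holds the record in 100m freestyle"], return_tensors="pt"
)

def prefix_allowed_tokens_fn(batch_id, input_ids):
    # Trivial constraint for demonstration: allow the generator's whole vocabulary.
    return list(range(model.config.generator.vocab_size))

generated = model.generate(
    input_ids=input_dict["input_ids"],
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    num_beams=2,
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```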
test_generation_logits_process.py
@@ -31,6 +31,7 @@ if is_torch_available():
        MinLengthLogitsProcessor,
        NoBadWordsLogitsProcessor,
        NoRepeatNGramLogitsProcessor,
+       PrefixConstrainedLogitsProcessor,
        RepetitionPenaltyLogitsProcessor,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
@@ -281,3 +282,23 @@ class LogitsProcessorTest(unittest.TestCase):
 
         # input_ids should never be changed
         self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
+
+    def test_prefix_constrained_logits_processor(self):
+        vocab_size = 5
+        batch_size = 2
+
+        input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+
+        def prefix_allowed_tokens_fn(batch_id, input_ids):
+            return [[0, 1], [2, 3]][batch_id]
+
+        prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, 1)
+
+        filtered_scores = prefix_constrained_logits_proc(input_ids, scores.clone())
+
+        # batch 1: only the 1st and 2nd tokens (0, 1) are allowed
+        # batch 2: only the 3rd and 4th tokens (2, 3) are allowed
+        self.assertListEqual(
+            torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]]
+        )
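The expected mask in this assertion can be reproduced directly from the `__call__` logic; a small standalone check (assumed, not part of the test suite):

```python
import math

import torch

scores = torch.zeros(2, 5)                 # uniform logits: batch_size=2, vocab_size=5
mask = torch.full_like(scores, -math.inf)
mask[0, [0, 1]] = 0                        # batch 0 may only emit tokens 0 or 1
mask[1, [2, 3]] = 0                        # batch 1 may only emit tokens 2 or 3

print(torch.isinf(scores + mask).tolist())
# [[False, False, True, True, True], [True, True, False, False, True]]
```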