TF generate refactor - Sample (#15793)

* Add TF logits wrappers
* Add sample method
* Add tests for TF logit wrappers
* TF generate sample tests now run on CPU

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
parent 96ae92be8c
commit baab5e7cdf
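For context, a minimal sketch of how the sampling path added in this commit is typically reached from user code. It assumes a local TensorFlow install and the public `gpt2` checkpoint; the generated text varies run to run because sampling is stochastic.

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
model.config.pad_token_id = model.config.eos_token_id  # GPT-2 has no pad token

input_ids = tokenizer("Today is a beautiful day, and", return_tensors="tf").input_ids

# num_beams=1 together with do_sample=True routes generation through the new
# `sample` method, which warps the logits with temperature / top-k / top-p
# before drawing the next token from a multinomial distribution.
output_ids = model.generate(
    input_ids,
    do_sample=True,
    num_beams=1,
    max_length=30,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```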
@@ -154,18 +154,30 @@ generation.

[[autodoc]] TFLogitsProcessorList
    - __call__

+[[autodoc]] TFLogitsWarper
+    - __call__
+
+[[autodoc]] TFTemperatureLogitsWarper
+    - __call__
+
+[[autodoc]] TFTopPLogitsWarper
+    - __call__
+
+[[autodoc]] TFTopKLogitsWarper
+    - __call__
+
[[autodoc]] TFMinLengthLogitsProcessor
    - __call__

[[autodoc]] TFNoBadWordsLogitsProcessor
    - __call__

[[autodoc]] TFNoRepeatNGramLogitsProcessor
    - __call__

[[autodoc]] TFRepetitionPenaltyLogitsProcessor
    - __call__

[[autodoc]] FlaxLogitsProcessor
    - __call__

@@ -1656,10 +1656,14 @@ if is_tf_available():
    _import_structure["generation_tf_logits_process"] = [
        "TFLogitsProcessor",
        "TFLogitsProcessorList",
+        "TFLogitsWarper",
        "TFMinLengthLogitsProcessor",
        "TFNoBadWordsLogitsProcessor",
        "TFNoRepeatNGramLogitsProcessor",
        "TFRepetitionPenaltyLogitsProcessor",
+        "TFTemperatureLogitsWarper",
+        "TFTopKLogitsWarper",
+        "TFTopPLogitsWarper",
    ]
    _import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"]
    _import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]

@@ -3706,10 +3710,14 @@ if TYPE_CHECKING:
        from .generation_tf_logits_process import (
            TFLogitsProcessor,
            TFLogitsProcessorList,
+            TFLogitsWarper,
            TFMinLengthLogitsProcessor,
            TFNoBadWordsLogitsProcessor,
            TFNoRepeatNGramLogitsProcessor,
            TFRepetitionPenaltyLogitsProcessor,
+            TFTemperatureLogitsWarper,
+            TFTopKLogitsWarper,
+            TFTopPLogitsWarper,
        )
        from .generation_tf_utils import tf_top_k_top_p_filtering
        from .keras_callbacks import KerasMetricCallback, PushToHubCallback
@@ -94,7 +94,7 @@ class FlaxLogitsProcessorList(list):

class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
    r"""
-    [`LogitsWarper`] for temperature (exponential scaling output probability distribution).
+    [`FlaxLogitsWarper`] for temperature (exponential scaling output probability distribution).

    Args:
        temperature (`float`):

@@ -114,7 +114,7 @@ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):

class FlaxTopPLogitsWarper(FlaxLogitsWarper):
    """
-    [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
+    [`FlaxLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.

    Args:
        top_p (`float`):

@@ -155,7 +155,7 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):

class FlaxTopKLogitsWarper(FlaxLogitsWarper):
    r"""
-    [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
+    [`FlaxLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.

    Args:
        top_k (`int`):

@@ -326,7 +326,7 @@ class FlaxGenerationMixin:
        raise NotImplementedError("`Beam sampling is currently not implemented.")

    def _get_logits_warper(
-        self, top_k: int = None, top_p: float = None, temperature: float = None
+        self, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
    ) -> FlaxLogitsProcessorList:
        """
        This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsWarper`]
@@ -58,6 +58,17 @@ class TFLogitsProcessor:
        )


+class TFLogitsWarper:
+    """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
+
+    @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+        """TF method for warping logits."""
+        raise NotImplementedError(
+            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
+        )
+
+
class TFLogitsProcessorList(list):
    """
    This class can be used to create a list of [`TFLogitsProcessor`] to subsequently process a `scores` input tensor.

@@ -81,6 +92,109 @@ class TFLogitsProcessorList(list):
        return scores


+class TFTemperatureLogitsWarper(TFLogitsWarper):
+    r"""
+    [`TFLogitsWarper`] for temperature (exponential scaling output probability distribution).
+
+    Args:
+        temperature (`float`):
+            The value used to module the logits distribution.
+    """
+
+    def __init__(self, temperature: float):
+        if not isinstance(temperature, float) or not (temperature > 0):
+            raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
+
+        self.temperature = temperature
+
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+        scores = scores / self.temperature
+        return scores
+
+
+class TFTopKLogitsWarper(TFLogitsWarper):
+    r"""
+    [`TFLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
+
+    Args:
+        top_k (`int`):
+            The number of highest probability vocabulary tokens to keep for top-k-filtering.
+        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
+            All filtered values will be set to this float value.
+        min_tokens_to_keep (`int`, *optional*, defaults to 1):
+            Minimum number of tokens that cannot be filtered.
+    """
+
+    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+        if not isinstance(top_k, int) or top_k <= 0:
+            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
+
+        self.top_k = top_k
+        self.filter_value = filter_value
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+        top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1])  # Safety check
+        # Boolean mask containing all tokens with a probability less than the last token of the top-k
+        indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
+        next_scores = tf.where(indices_to_remove, self.filter_value, scores)
+        return next_scores
+
+
+class TFTopPLogitsWarper(TFLogitsWarper):
+    """
+    [`TFLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to <= prob_cut_off.
+
+    Args:
+        top_p (`float`):
+            If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept
+            for generation.
+        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
+            All filtered values will be set to this float value.
+        min_tokens_to_keep (`int`, *optional*, defaults to 1):
+            Minimum number of tokens that cannot be filtered.
+    """
+
+    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+        if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
+            raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
+
+        self.top_p = top_p
+        self.filter_value = filter_value
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+        topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])
+
+        mask_scores = tf.fill(scores.shape, self.filter_value)
+        cumulative_probs = tf.math.cumsum(tf.nn.softmax(topk_scores, axis=-1), axis=-1)
+        score_mask = cumulative_probs < self.top_p
+
+        # Also include the token that is higher than top_p (the first false = shift and insert a True on the left)
+        score_mask = tf.concat((tf.ones([score_mask.shape[0], 1], dtype=tf.bool), score_mask[:, :-1]), axis=-1)
+
+        # Ensure min tokens to keep
+        score_mask = tf.concat(
+            (
+                tf.ones([score_mask.shape[0], self.min_tokens_to_keep], dtype=tf.bool),
+                score_mask[:, self.min_tokens_to_keep :],
+            ),
+            axis=-1,
+        )
+
+        # Mask the values that do not fit the criteria
+        topk_next_scores = tf.where(score_mask, topk_scores, mask_scores)
+
+        # Undo the topk sorting: converts the 2D matrix of per-row original indices of shape (batch_size, vocab_size)
+        # to a 3D tensor of shape (batch_size, vocab_size, 2) containing the original score coordinate, from which we
+        # can scatter (i.e. `scatter_indices[row, col, :]` is a tensor containing `[row, topk_indices[row, col]]`)
+        scatter_rows = tf.tile(tf.expand_dims(tf.range(topk_indices.shape[0]), axis=-1), [1, topk_indices.shape[-1]])
+        scatter_indices = tf.stack((scatter_rows, topk_indices), axis=-1)
+        next_scores = tf.scatter_nd(scatter_indices, topk_next_scores, shape=topk_next_scores.shape)
+
+        return next_scores
+
+
class TFMinLengthLogitsProcessor(TFLogitsProcessor):
    r"""
    [`TFLogitsProcessor`] enforcing a min-length by setting EOS probability to 0.
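The warpers above are plain callables over `(input_ids, scores)`, so they can be exercised outside of `generate` as well. A small sketch, assuming eager TensorFlow and fake random logits, that chains them in the same order `_get_logits_warper` uses (temperature, then top-k, then top-p):

```python
import tensorflow as tf
from transformers import TFTemperatureLogitsWarper, TFTopKLogitsWarper, TFTopPLogitsWarper

batch_size, vocab_size = 2, 8
scores = tf.random.normal((batch_size, vocab_size))  # stand-in for next-token logits
input_ids = None  # none of these warpers actually read input_ids

for warper in (
    TFTemperatureLogitsWarper(0.7),
    TFTopKLogitsWarper(top_k=5),
    TFTopPLogitsWarper(top_p=0.9),
):
    scores = warper(input_ids, scores)

# Filtered positions hold -inf, so softmax assigns them zero probability
probs = tf.nn.softmax(scores, axis=-1)
print(probs.numpy())
```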
@@ -28,6 +28,9 @@ from .generation_tf_logits_process import (
    TFNoBadWordsLogitsProcessor,
    TFNoRepeatNGramLogitsProcessor,
    TFRepetitionPenaltyLogitsProcessor,
+    TFTemperatureLogitsWarper,
+    TFTopKLogitsWarper,
+    TFTopPLogitsWarper,
)
from .tf_utils import set_tensor_by_indices_to_value, shape_list
from .utils import logging

@@ -558,9 +561,7 @@ class TFGenerationMixin:
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        do_sample = do_sample if do_sample is not None else self.config.do_sample

-        is_greedy_gen_mode = num_beams == 1 and do_sample is False
-
-        if is_greedy_gen_mode:
+        if num_beams == 1:
            return self._generate(
                input_ids=input_ids,
                max_length=max_length,
@@ -790,304 +791,34 @@ class TFGenerationMixin:
            cur_len < max_length
        ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"

-        if num_beams == 1:
-            return self._generate_no_beam_search(
-                input_ids,
-                cur_len=cur_len,
-                max_length=max_length,
-                min_length=min_length,
-                do_sample=do_sample,
-                temperature=temperature,
-                top_k=top_k,
-                top_p=top_p,
-                repetition_penalty=repetition_penalty,
-                no_repeat_ngram_size=no_repeat_ngram_size,
-                bad_words_ids=bad_words_ids,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
-                batch_size=effective_batch_size,
-                vocab_size=vocab_size,
-                encoder_outputs=encoder_outputs,
-                attention_mask=attention_mask,
-                use_cache=use_cache,
-                return_dict_in_generate=return_dict_in_generate,
-                **model_kwargs,
-            )
-        else:
-            return self._generate_beam_search(
-                input_ids,
-                cur_len=cur_len,
-                max_length=max_length,
-                min_length=min_length,
-                do_sample=do_sample,
-                early_stopping=early_stopping,
-                temperature=temperature,
-                top_k=top_k,
-                top_p=top_p,
-                repetition_penalty=repetition_penalty,
-                no_repeat_ngram_size=no_repeat_ngram_size,
-                bad_words_ids=bad_words_ids,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
-                batch_size=effective_batch_size,
-                num_return_sequences=num_return_sequences,
-                length_penalty=length_penalty,
-                num_beams=num_beams,
-                vocab_size=vocab_size,
-                encoder_outputs=encoder_outputs,
-                attention_mask=attention_mask,
-                use_cache=use_cache,
-                forced_bos_token_id=forced_bos_token_id,
-                forced_eos_token_id=forced_eos_token_id,
-                return_dict_in_generate=return_dict_in_generate,
-                **model_kwargs,
-            )
-
-    def _generate_no_beam_search(
-        self,
-        input_ids,
-        cur_len,
-        max_length,
-        min_length,
-        do_sample,
-        temperature,
-        top_k,
-        top_p,
-        repetition_penalty,
-        no_repeat_ngram_size,
-        bad_words_ids,
-        pad_token_id,
-        eos_token_id,
-        batch_size,
-        vocab_size,
-        encoder_outputs,
-        attention_mask,
-        use_cache,
-        return_dict_in_generate,
-        **kwargs
-    ) -> Union[TFGreedySearchOutput, TFSampleOutput, tf.Tensor]:
-        """
-        Generate sequences for each example without beam search (num_beams == 1). All returned sequences are generated
-        independently.
-        """
-
-        # length of generated sentences / unfinished sentences
-        unfinished_sents = tf.ones_like(input_ids[:, 0])
-        sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length
-
-        # defined for encoder-decoder models, None for decoder-only models
-        past = encoder_outputs
-
-        # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and kwargs["output_scores"]) else None
-        decoder_attentions = () if (return_dict_in_generate and kwargs["output_attentions"]) else None
-        cross_attentions = () if (return_dict_in_generate and kwargs["output_attentions"]) else None
-        decoder_hidden_states = () if (return_dict_in_generate and kwargs["output_hidden_states"]) else None
-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if self.config.is_encoder_decoder:
-            encoder_attentions = (
-                kwargs["encoder_attentions"] if (return_dict_in_generate and kwargs["encoder_attentions"]) else None
-            )
-            encoder_hidden_states = (
-                kwargs["encoder_hidden_states"]
-                if (return_dict_in_generate and kwargs["encoder_hidden_states"])
-                else None
-            )
-
-        while cur_len < max_length:
-            model_inputs = self.prepare_inputs_for_generation(
-                input_ids,
-                past=past,
-                attention_mask=attention_mask,
-                use_cache=use_cache,
-                **kwargs,
-            )
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=kwargs["output_attentions"],
-                output_hidden_states=kwargs["output_hidden_states"],
-            )
-            next_token_logits = outputs.logits[:, -1, :]  # (batch_size * num_beams, vocab_size)
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if kwargs["output_scores"]:
-                    scores += (next_token_logits,)
-                if kwargs["output_attentions"]:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
-
-                if kwargs["output_hidden_states"]:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
-
-            # if model has past, then set the past variable to speed up decoding
-            if self._use_cache(outputs, use_cache):
-                past = outputs[1]
-
-            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
-            if repetition_penalty != 1.0:
-                next_token_logits_penalties = _create_next_token_logits_penalties(
-                    input_ids, next_token_logits, repetition_penalty
-                )
-                next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
-
-            if no_repeat_ngram_size > 0:
-                # calculate a list of banned tokens to prevent repetitively generating the same ngrams
-                # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
-                banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
-                # create banned_tokens boolean mask
-                banned_tokens_indices_mask = []
-                for banned_tokens_slice in banned_tokens:
-                    banned_tokens_indices_mask.append(
-                        [True if token in banned_tokens_slice else False for token in range(vocab_size)]
-                    )
-
-                next_token_logits = set_tensor_by_indices_to_value(
-                    next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
-                )
-
-            if bad_words_ids is not None:
-                # calculate a list of banned tokens according to bad words
-                banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
-
-                banned_tokens_indices_mask = []
-                for banned_tokens_slice in banned_tokens:
-                    banned_tokens_indices_mask.append(
-                        [True if token in banned_tokens_slice else False for token in range(vocab_size)]
-                    )
-
-                next_token_logits = set_tensor_by_indices_to_value(
-                    next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
-                )
-
-            # set eos token prob to zero if min_length is not reached
-            if eos_token_id is not None and cur_len < min_length:
-                # create eos_token_id boolean mask
-                is_token_logit_eos_token = tf.convert_to_tensor(
-                    [True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
-                )
-                eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
-
-                next_token_logits = set_tensor_by_indices_to_value(
-                    next_token_logits, eos_token_indices_mask, -float("inf")
-                )
-
-            if do_sample:
-                # Temperature (higher temperature => more likely to sample low probability tokens)
-                if temperature != 1.0:
-                    next_token_logits = next_token_logits / temperature
-                # Top-p/top-k filtering
-                next_token_logits = tf_top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
-                # Sample
-                next_token = tf.squeeze(
-                    tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=1), axis=1
-                )
-            else:
-                # Greedy decoding
-                next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32)
-
-            # update generations and finished sentences
-            if eos_token_id is not None:
-                # pad finished sentences if eos_token_id exist
-                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
-            else:
-                tokens_to_add = next_token
-
-            # add token and increase length by one
-            input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)
-            cur_len = cur_len + 1
-
-            if eos_token_id is not None:
-                eos_in_sents = tokens_to_add == eos_token_id
-                # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
-                is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
-                    unfinished_sents, tf.cast(eos_in_sents, tf.int32)
-                )
-                sent_lengths = (
-                    sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
-                    + cur_len * is_sents_unfinished_and_token_to_add_is_eos
-                )
-
-                # unfinished_sents is set to zero if eos in sentence
-                unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos
-
-            # stop when there is a </s> in each sentence, or if we exceed the maximum length
-            if tf.math.reduce_max(unfinished_sents) == 0:
-                break
-
-            # extend attention_mask for new generated input if only decoder
-            if self.config.is_encoder_decoder is False:
-                attention_mask = tf.concat(
-                    [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
-                )
-
-        # if there are different sentences lengths in the batch, some batches have to be padded
-        min_sent_length = tf.math.reduce_min(sent_lengths)
-        max_sent_length = tf.math.reduce_max(sent_lengths)
-        if min_sent_length != max_sent_length:
-            assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
-            # finished sents are filled with pad_token
-            padding = tf.ones([batch_size, max_sent_length.numpy()], dtype=tf.int32) * pad_token_id
-
-            # create length masks for tf.where operation
-            broad_casted_sent_lengths = tf.broadcast_to(
-                tf.expand_dims(sent_lengths, -1), [batch_size, max_sent_length]
-            )
-            broad_casted_range = tf.transpose(
-                tf.broadcast_to(tf.expand_dims(tf.range(max_sent_length), -1), [max_sent_length, batch_size])
-            )
-
-            decoded = tf.where(broad_casted_range < broad_casted_sent_lengths, input_ids, padding)
-        else:
-            decoded = input_ids
-
-        if return_dict_in_generate:
-            if do_sample:
-                if self.config.is_encoder_decoder:
-                    return TFSampleEncoderDecoderOutput(
-                        sequences=decoded,
-                        scores=scores,
-                        encoder_attentions=encoder_attentions,
-                        encoder_hidden_states=encoder_hidden_states,
-                        decoder_attentions=decoder_attentions,
-                        cross_attentions=cross_attentions,
-                        decoder_hidden_states=decoder_hidden_states,
-                    )
-                else:
-                    return TFSampleDecoderOnlyOutput(
-                        sequences=decoded,
-                        scores=scores,
-                        attentions=decoder_attentions,
-                        hidden_states=decoder_hidden_states,
-                    )
-            else:
-                if self.config.is_encoder_decoder:
-                    return TFGreedySearchEncoderDecoderOutput(
-                        sequences=decoded,
-                        scores=scores,
-                        encoder_attentions=encoder_attentions,
-                        encoder_hidden_states=encoder_hidden_states,
-                        decoder_attentions=decoder_attentions,
-                        cross_attentions=cross_attentions,
-                        decoder_hidden_states=decoder_hidden_states,
-                    )
-                else:
-                    return TFGreedySearchDecoderOnlyOutput(
-                        sequences=decoded,
-                        scores=scores,
-                        attentions=decoder_attentions,
-                        hidden_states=decoder_hidden_states,
-                    )
-        else:
-            return decoded
+        return self._generate_beam_search(
+            input_ids,
+            cur_len=cur_len,
+            max_length=max_length,
+            min_length=min_length,
+            do_sample=do_sample,
+            early_stopping=early_stopping,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            bad_words_ids=bad_words_ids,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            batch_size=effective_batch_size,
+            num_return_sequences=num_return_sequences,
+            length_penalty=length_penalty,
+            num_beams=num_beams,
+            vocab_size=vocab_size,
+            encoder_outputs=encoder_outputs,
+            attention_mask=attention_mask,
+            use_cache=use_cache,
+            forced_bos_token_id=forced_bos_token_id,
+            forced_eos_token_id=forced_eos_token_id,
+            return_dict_in_generate=return_dict_in_generate,
+            **model_kwargs,
+        )

    def _generate_beam_search(
        self,
@@ -1761,11 +1492,6 @@ class TFGenerationMixin:
            input_ids, return_dict_in_generate, model_kwargs
        )

-        # TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger
-        # refactor of all generation models in TF. `past` should be
-        # optional everywhere and not be set equal to encoder_outputs
-        model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None
-
        # 4. Prepare `input_ids` which will be used for auto-regressive generation
        if self.config.is_encoder_decoder:
            # if encoder-decoder then `input_ids` come from `decoder_start_token_id`

@@ -1787,6 +1513,7 @@ class TFGenerationMixin:
        # 5. determine generation mode
        # TODO(Matt, Joao, Patrick) - add more use cases here
        is_greedy_gen_mode = (num_beams == 1) and do_sample is False
+        is_sample_gen_mode = (num_beams == 1) and do_sample is True

        # 6. prepare distribution pre_processing samplers
        logits_processor = self._get_logits_processor(

@@ -1804,6 +1531,10 @@ class TFGenerationMixin:
                    f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
                )

+            # TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all
+            # generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs.
+            model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None
+
            # 8. run greedy search
            return self.greedy_search(
                input_ids,

@@ -1816,6 +1547,35 @@ class TFGenerationMixin:
                **model_kwargs,
            )

+        elif is_sample_gen_mode:
+            # 8. prepare logits warper
+            logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
+
+            # 9. expand input_ids with `num_return_sequences` additional sequences per batch
+            input_ids, model_kwargs = self._expand_inputs_for_generation(
+                input_ids,
+                expand_size=num_return_sequences,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                **model_kwargs,
+            )
+
+            # TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all
+            # generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs.
+            model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None
+
+            # 10. run sample
+            return self.sample(
+                input_ids,
+                logits_processor=logits_processor,
+                logits_warper=logits_warper,
+                max_length=max_length,
+                pad_token_id=pad_token_id,
+                eos_token_id=eos_token_id,
+                output_scores=output_scores,
+                return_dict_in_generate=return_dict_in_generate,
+                **model_kwargs,
+            )
+
        # TODO(Matt, Joao, Patrick) - add more sub-generation methods here

    def _prepare_attention_mask_for_generation(
@@ -1908,6 +1668,36 @@ class TFGenerationMixin:
                "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
            )

+    @staticmethod
+    def _expand_inputs_for_generation(
+        input_ids: tf.Tensor,
+        expand_size: int = 1,
+        is_encoder_decoder: bool = False,
+        attention_mask: Optional[tf.Tensor] = None,
+        encoder_outputs: Optional[ModelOutput] = None,
+        **model_kwargs,
+    ) -> Tuple[tf.Tensor, Dict[str, Any]]:
+        expanded_return_idx = tf.reshape(
+            tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1)
+        )
+        input_ids = tf.gather(input_ids, expanded_return_idx, axis=0)
+
+        if "token_type_ids" in model_kwargs:
+            token_type_ids = model_kwargs["token_type_ids"]
+            model_kwargs["token_type_ids"] = tf.gather(token_type_ids, expanded_return_idx, axis=0)
+
+        if attention_mask is not None:
+            model_kwargs["attention_mask"] = tf.gather(attention_mask, expanded_return_idx, axis=0)
+
+        if is_encoder_decoder:
+            if encoder_outputs is None:
+                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+            encoder_outputs["last_hidden_state"] = tf.gather(
+                encoder_outputs.last_hidden_state, expanded_return_idx, axis=0
+            )
+            model_kwargs["encoder_outputs"] = encoder_outputs
+        return input_ids, model_kwargs
+
    def _prepare_model_inputs(self, inputs: Optional[tf.Tensor] = None, bos_token_id: Optional[int] = None):
        # TODO(Patrick) - adapt this function when making `generate` more flexible
        # for all kinds of input types
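The index arithmetic in `_expand_inputs_for_generation` can be checked in isolation. A tiny sketch (written with `(-1,)` as the reshape target for clarity) showing the repeat pattern it produces for `expand_size=3` on a batch of two prompts:

```python
import tensorflow as tf

input_ids = tf.constant([[11, 12], [21, 22]])
expand_size = 3

# Build [0, 0, 0, 1, 1, 1]: each batch row index is repeated expand_size times
expanded_return_idx = tf.reshape(
    tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1,)
)
print(expanded_return_idx.numpy())  # [0 0 0 1 1 1]

# Gathering with that index duplicates every prompt expand_size times,
# which is how num_return_sequences > 1 is handled for sampling
print(tf.gather(input_ids, expanded_return_idx, axis=0).numpy())
# [[11 12] [11 12] [11 12] [21 22] [21 22] [21 22]]
```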
@@ -1956,6 +1746,34 @@ class TFGenerationMixin:

        return model_kwargs

+    def _get_logits_warper(
+        self,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        temperature: Optional[float] = None,
+    ) -> TFLogitsProcessorList:
+        """
+        This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsWarper`]
+        instances used for multinomial sampling.
+        """
+
+        # init warp parameters
+        top_k = top_k if top_k is not None else self.config.top_k
+        top_p = top_p if top_p is not None else self.config.top_p
+        temperature = temperature if temperature is not None else self.config.temperature
+        # instantiate warpers list
+        warpers = TFLogitsProcessorList()
+
+        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
+        # all samplers can be found in `generation_utils_samplers.py`
+        if temperature is not None and temperature != 1.0:
+            warpers.append(TFTemperatureLogitsWarper(temperature))
+        if top_k is not None and top_k != 0:
+            warpers.append(TFTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
+        if top_p is not None and top_p < 1.0:
+            warpers.append(TFTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
+        return warpers
+
    def _get_logits_processor(
        self,
        repetition_penalty: float,
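The selection rules above only add a warper when its parameter actually changes the distribution (temperature != 1.0, top_k != 0, top_p < 1.0). A standalone sketch with a hypothetical `build_warpers` helper that mirrors that logic, useful for inspecting which warpers a given set of generation parameters would enable:

```python
from transformers import (
    TFLogitsProcessorList,
    TFTemperatureLogitsWarper,
    TFTopKLogitsWarper,
    TFTopPLogitsWarper,
)


def build_warpers(top_k=None, top_p=None, temperature=None):
    # Hypothetical helper mirroring `_get_logits_warper`: a warper is only
    # appended when its parameter is set and would change the distribution.
    warpers = TFLogitsProcessorList()
    if temperature is not None and temperature != 1.0:
        warpers.append(TFTemperatureLogitsWarper(temperature))
    if top_k is not None and top_k != 0:
        warpers.append(TFTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
    if top_p is not None and top_p < 1.0:
        warpers.append(TFTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
    return warpers


# temperature=1.0 is a no-op, so only the top-k and top-p warpers are added here
print(build_warpers(top_k=50, top_p=0.95, temperature=1.0))
```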
@@ -2029,8 +1847,8 @@ class TFGenerationMixin:
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
            model_kwargs:
-                Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
-                If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
+                Additional model specific keyword arguments will be forwarded to the `call` function of the model. If
+                model is an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_tf_utils.TFGreedySearchDecoderOnlyOutput`],

@@ -2043,13 +1861,13 @@ class TFGenerationMixin:

        ```python
        >>> from transformers import (
-        ...     TFAutoTokenizer,
+        ...     AutoTokenizer,
        ...     TFAutoModelForCausalLM,
        ...     TFLogitsProcessorList,
        ...     TFMinLengthLogitsProcessor,
        ... )

-        >>> tokenizer = TFAutoTokenizer.from_pretrained("gpt2")
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = TFAutoModelForCausalLM.from_pretrained("gpt2")

        >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token
@@ -2195,6 +2013,230 @@ class TFGenerationMixin:
        else:
            return input_ids

+    def sample(
+        self,
+        input_ids: tf.Tensor,
+        logits_processor: Optional[TFLogitsProcessorList] = None,
+        logits_warper: Optional[TFLogitsProcessorList] = None,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_scores: Optional[bool] = None,
+        return_dict_in_generate: Optional[bool] = None,
+        **model_kwargs,
+    ) -> Union[TFSampleOutput, tf.Tensor]:
+        r"""
+        Generates sequences for models with a language modeling head using multinomial sampling.
+
+        Parameters:
+
+            input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+                The sequence used as a prompt for the generation.
+            logits_processor (`TFLogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`]
+                used to modify the prediction scores of the language modeling head applied at each generation step.
+            logits_warper (`TFLogitsProcessorList`, *optional*):
+                An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsWarper`]
+                used to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step.
+            max_length (`int`, *optional*, defaults to 20):
+                The maximum length of the sequence to be generated.
+            pad_token_id (`int`, *optional*):
+                The id of the *padding* token.
+            eos_token_id (`int`, *optional*):
+                The id of the *end-of-sequence* token.
+            output_attentions (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more details.
+            output_hidden_states (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more details.
+            output_scores (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
+            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
+                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+            model_kwargs:
+                Additional model specific kwargs will be forwarded to the `call` function of the model. If model is an
+                encoder-decoder model the kwargs should include `encoder_outputs`.
+
+        Return:
+            [`~generation_tf_utils.TFSampleDecoderOnlyOutput`], [`~generation_tf_utils.TFSampleEncoderDecoderOutput`]
+            or `tf.Tensor`: A `tf.Tensor` containing the generated tokens (default behaviour) or a
+            [`~generation_tf_utils.TFSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
+            `return_dict_in_generate=True` or a [`~generation_tf_utils.TFSampleEncoderDecoderOutput`] if
+            `model.config.is_encoder_decoder=True`.
+
+        Examples:
+
+        ```python
+        >>> from transformers import (
+        ...     AutoTokenizer,
+        ...     TFAutoModelForCausalLM,
+        ...     TFLogitsProcessorList,
+        ...     TFMinLengthLogitsProcessor,
+        ...     TFTopKLogitsWarper,
+        ...     TFTemperatureLogitsWarper,
+        ... )
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>> model = TFAutoModelForCausalLM.from_pretrained("gpt2")
+
+        >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token
+        >>> model.config.pad_token_id = model.config.eos_token_id
+
+        >>> input_prompt = "Today is a beautiful day, and"
+        >>> input_ids = tokenizer(input_prompt, return_tensors="tf").input_ids
+
+        >>> # instantiate logits processors
+        >>> logits_processor = TFLogitsProcessorList(
+        ...     [
+        ...         TFMinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
+        ...     ]
+        ... )
+        >>> # instantiate logits processors
+        >>> logits_warper = TFLogitsProcessorList(
+        ...     [
+        ...         TFTopKLogitsWarper(50),
+        ...         TFTemperatureLogitsWarper(0.7),
+        ...     ]
+        ... )
+
+        >>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper)
+
+        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+        ```"""
+
+        # init values
+        logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
+        logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
+
+        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        output_scores = output_scores if output_scores is not None else self.config.output_scores
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict_in_generate = (
+            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
+        )
+
+        # init attention / hidden states / scores tuples
+        scores = () if (return_dict_in_generate and output_scores) else None
+        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+        # TODO(Patrick): `encoder_outputs`, `past` hack. Currently T5, Bart expect `encoder_outputs`
+        # to be wrapped into `past` variable. This is a bad design and needs to be updated.
+        # Remove the following lines when updating all encoder-decoder models
+        encoder_outputs = model_kwargs.pop("encoder_outputs", None)
+
+        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+        if return_dict_in_generate and self.config.is_encoder_decoder:
+            encoder_attentions = encoder_outputs.get("attentions") if output_attentions else None
+            encoder_hidden_states = encoder_outputs.get("hidden_states") if output_hidden_states else None
+
+        # keep track of which sequences are already finished
+        unfinished_sequences = tf.ones_like(input_ids[:, 0])
+        cur_len = input_ids.shape[-1]
+
+        while cur_len < max_length:
+            # TODO(Patrick): remove following line by cleaning up `prepare_inputs_for_generation`
+            # in all models
+            model_kwargs["use_cache"] = None if "use_cache" not in model_kwargs else model_kwargs["use_cache"]
+
+            # prepare model inputs
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+            # forward pass to get next token
+            outputs = self(
+                **model_inputs,
+                return_dict=True,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )
+
+            next_token_logits = outputs.logits[:, -1, :]
+
+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+            next_token_scores = logits_warper(input_ids, next_token_scores)
+
+            # Store scores, attentions and hidden_states when required
+            if return_dict_in_generate:
+                if output_scores:
+                    scores += (next_token_scores,)
+                if output_attentions:
+                    decoder_attentions += (
+                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+                    )
+                    if self.config.is_encoder_decoder:
+                        cross_attentions += (outputs.cross_attentions,)
+
+                if output_hidden_states:
+                    decoder_hidden_states += (
+                        (outputs.decoder_hidden_states,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.hidden_states,)
+                    )
+
+            # sample
+            next_tokens = tf.squeeze(
+                tf.random.categorical(logits=next_token_scores, num_samples=1, dtype=tf.int32), axis=1
+            )
+
+            # finished sentences should have their next token be a padding token
+            if eos_token_id is not None:
+                if pad_token_id is None:
+                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+            # update generated ids, model inputs, and length for next step
+            input_ids = tf.concat([input_ids, next_tokens[:, None]], axis=-1)
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+            )
+            cur_len = cur_len + 1
+
+            # if eos_token was found in one sentence, set sentence to finished
+            if eos_token_id is not None:
+                eos_in_sents = next_tokens == eos_token_id
+                # if sentence is unfinished and the token to add is eos
+                is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
+                    unfinished_sequences, tf.cast(eos_in_sents, tf.int32)
+                )
+
+                # unfinished_sequences is set to zero if eos in sentence
+                unfinished_sequences -= is_sents_unfinished_and_token_to_add_is_eos
+
+            # stop when each sentence is finished, or if we exceed the maximum length
+            if tf.math.reduce_max(unfinished_sequences) == 0:
+                break
+
+        if return_dict_in_generate:
+            if self.config.is_encoder_decoder:
+                return TFSampleEncoderDecoderOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    encoder_attentions=encoder_attentions,
+                    encoder_hidden_states=encoder_hidden_states,
+                    decoder_attentions=decoder_attentions,
+                    cross_attentions=cross_attentions,
+                    decoder_hidden_states=decoder_hidden_states,
+                )
+            else:
+                return TFSampleDecoderOnlyOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    attentions=decoder_attentions,
+                    hidden_states=decoder_hidden_states,
+                )
+        else:
+            return input_ids
+

def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
    # create logit penalties for already seen input_ids
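One practical note on the new `sample` loop: `tf.random.categorical` draws from TensorFlow's global random generator, so repeated calls give different outputs by default. A minimal sketch of making the docstring example above repeatable across runs:

```python
import tensorflow as tf

# Seed TensorFlow's global RNG before calling `model.sample(...)` or
# `model.generate(..., do_sample=True)` so the drawn tokens are reproducible.
tf.random.set_seed(0)
```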
@ -2292,7 +2334,6 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
|
|||||||
# Remove all tokens with a probability less than the last token of the top-k
|
# Remove all tokens with a probability less than the last token of the top-k
|
||||||
indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
|
indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
|
||||||
logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
|
logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
|
||||||
|
|
||||||
if top_p < 1.0:
|
if top_p < 1.0:
|
||||||
sorted_indices = tf.argsort(logits, direction="DESCENDING")
|
sorted_indices = tf.argsort(logits, direction="DESCENDING")
|
||||||
sorted_logits = tf.gather(
|
sorted_logits = tf.gather(
|
||||||
|
@@ -556,8 +556,8 @@ class GenerationMixin:
        input_ids: torch.LongTensor,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
-        attention_mask: torch.LongTensor = None,
-        encoder_outputs: ModelOutput = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        encoder_outputs: Optional[ModelOutput] = None,
        **model_kwargs,
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        expanded_return_idx = (

@@ -617,11 +617,11 @@ class GenerationMixin:

    def _get_logits_warper(
        self,
-        top_k: int = None,
-        top_p: float = None,
-        typical_p: float = None,
-        temperature: float = None,
-        num_beams: int = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        typical_p: Optional[float] = None,
+        temperature: Optional[float] = None,
+        num_beams: Optional[int] = None,
    ) -> LogitsProcessorList:
        """
        This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
@@ -31,6 +31,13 @@ class TFLogitsProcessorList(metaclass=DummyObject):
        requires_backends(self, ["tf"])


+class TFLogitsWarper(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
class TFMinLengthLogitsProcessor(metaclass=DummyObject):
    _backends = ["tf"]


@@ -59,6 +66,27 @@ class TFRepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
        requires_backends(self, ["tf"])


+class TFTemperatureLogitsWarper(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFTopKLogitsWarper(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFTopPLogitsWarper(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
def tf_top_k_top_p_filtering(*args, **kwargs):
    requires_backends(tf_top_k_top_p_filtering, ["tf"])

@@ -51,7 +51,7 @@ class LogitsProcessorTest(unittest.TestCase):
        scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length
        return scores

-    def test_min_lenght_dist_processor(self):
+    def test_min_length_dist_processor(self):
        vocab_size = 20
        batch_size = 4
        eos_token_id = 0
@@ -16,6 +16,8 @@

import unittest

+import numpy as np
+
from transformers import is_tf_available
from transformers.testing_utils import require_tf


@@ -29,6 +31,9 @@ if is_tf_available():
        TFNoBadWordsLogitsProcessor,
        TFNoRepeatNGramLogitsProcessor,
        TFRepetitionPenaltyLogitsProcessor,
+        TFTemperatureLogitsWarper,
+        TFTopKLogitsWarper,
+        TFTopPLogitsWarper,
    )
    from transformers.tf_utils import set_tensor_by_indices_to_value

@ -38,7 +43,7 @@ if is_tf_available():
|
|||||||
@require_tf
|
@require_tf
|
||||||
class TFLogitsProcessorTest(unittest.TestCase):
|
class TFLogitsProcessorTest(unittest.TestCase):
|
||||||
def _get_uniform_logits(self, batch_size: int, length: int):
|
def _get_uniform_logits(self, batch_size: int, length: int):
|
||||||
scores = tf.ones((batch_size, length), dtype=tf.float32) / length
|
scores = np.ones((batch_size, length), dtype=np.float32) / length
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
def test_min_length_dist_processor(self):
|
def test_min_length_dist_processor(self):
|
||||||
@@ -60,6 +65,37 @@ class TFLogitsProcessorTest(unittest.TestCase):
         scores_before_min_length = min_dist_processor(input_ids, scores)
         self.assertFalse(tf.math.reduce_any(tf.math.is_inf(scores_before_min_length)).numpy())

+    def test_temperature_dist_warper(self):
+        input_ids = None
+        length = 20
+
+        scores = self._get_uniform_logits(batch_size=2, length=length)
+
+        # tweak scores to not be uniform anymore
+        scores[1, 5] = (1 / length) + 0.1  # peak, 1st batch
+        scores[1, 10] = (1 / length) - 0.4  # valley, 1st batch
+
+        # compute softmax
+        probs = tf.nn.softmax(scores, axis=-1)
+
+        temp_dist_warper_sharper = TFTemperatureLogitsWarper(temperature=0.5)
+        temp_dist_warper_smoother = TFTemperatureLogitsWarper(temperature=1.3)
+
+        warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores)), axis=-1)
+        warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores)), axis=-1)
+
+        # uniform distribution stays uniform
+        tf.debugging.assert_near(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)
+        tf.debugging.assert_near(probs[0, :], warped_prob_smooth[0, :], atol=1e-3)
+
+        # sharp peaks get higher, valleys get lower
+        self.assertLess(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_sharp[1, :]))
+        self.assertGreater(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_sharp[1, :]))
+
+        # smooth peaks get lower, valleys get higher
+        self.assertGreater(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_smooth[1, :]))
+        self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :]))
+
     def test_repetition_penalty_dist_process(self):
         input_ids = tf.constant([[0, 1], [5, 0]], dtype=tf.int32)
         vocab_size = 10
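The warper exercised here implements plain temperature scaling: logits are divided by the temperature before the softmax, so values below 1 sharpen the distribution and values above 1 flatten it. A minimal standalone sketch of that behaviour (not using the transformers class itself):

import tensorflow as tf

def scale_by_temperature(logits, temperature):
    # Equivalent in spirit to what the sharper/smoother assertions above check.
    return logits / temperature

logits = tf.constant([[2.0, 1.0, 0.5]])
print(tf.nn.softmax(scale_by_temperature(logits, 0.5)).numpy())  # more peaked than softmax(logits)
print(tf.nn.softmax(scale_by_temperature(logits, 1.3)).numpy())  # flatter than softmax(logits)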
@@ -82,6 +118,73 @@ class TFLogitsProcessorTest(unittest.TestCase):
         self.assertAlmostEqual(scores[1, 0].numpy(), (1 / vocab_size) / 2)
         self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2)

+    def test_top_k_dist_warper(self):
+        input_ids = None
+        vocab_size = 10
+        batch_size = 2
+
+        # create ramp distribution
+        ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy()
+        ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size
+
+        top_k_warp = TFTopKLogitsWarper(3)
+
+        scores = top_k_warp(input_ids, ramp_logits)
+
+        # check that correct tokens are filtered
+        self.assertListEqual(tf.math.is_inf(scores[0]).numpy().tolist(), 7 * [True] + 3 * [False])
+        self.assertListEqual(tf.math.is_inf(scores[1]).numpy().tolist(), 2 * [True] + 3 * [False] + 5 * [True])
+
+        # check special cases
+        length = 5
+
+        logits = self._get_uniform_logits(batch_size=batch_size, length=length)
+        top_k_warp_safety_check = TFTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
+
+        scores = top_k_warp_safety_check(input_ids, logits)
+        # uniform dist is not changed
+        self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [0, 0])
+
+        ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy()
+        scores = top_k_warp_safety_check(input_ids, ramp_logits)
+
+        # min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
+        self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [2, 2])
+
+    def test_top_p_dist_warper(self):
+        input_ids = None
+        vocab_size = 10
+        batch_size = 2
+
+        # create distribution and take log (inverse to Softmax as taken in TFTopPLogitsWarper)
+        dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], dtype=np.float32))
+
+        top_p_warp = TFTopPLogitsWarper(0.7)
+        filtered_dist = tf.exp(top_p_warp(input_ids, dist))
+
+        # dist should be filtered to keep min num values so that sum is >= 0.7
+        # exp (-inf) => 0
+        EXPECTED_FILTERED_DIST = tf.constant([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], dtype=tf.float32)
+        tf.debugging.assert_near(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)
+
+        # check edge cases with negative and extreme logits
+        ramp_logits = np.broadcast_to(
+            np.arange(vocab_size, dtype=np.float32)[None, :], (batch_size, vocab_size)
+        ).copy() - (vocab_size // 2)
+
+        # make ramp_logits more extreme
+        ramp_logits[1] = ramp_logits[1] * 100.0
+
+        # make sure at least 2 tokens are kept
+        top_p_warp = TFTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
+        filtered_dist = top_p_warp(input_ids, ramp_logits)
+
+        # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps
+        # 2.
+        self.assertListEqual(
+            tf.math.reduce_sum(tf.where(filtered_dist != 0.0, 1, 0), axis=-1).numpy().tolist(), [3, 2]
+        )
+
     def test_no_repeat_ngram_dist_processor(self):
         vocab_size = 3
         batch_size = 2
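Both warpers apply a filtering rule: top-k keeps only the k highest logits, top-p keeps the smallest set of tokens whose cumulative probability reaches p, and everything else is set to the filter value (typically -inf). A rough NumPy-only sketch of the top-k rule, for illustration (the helper name is made up and is not the transformers implementation):

import numpy as np

def top_k_filter(logits, k, filter_value=-float("inf")):
    # Keep the k largest logits in each row and replace the rest with filter_value.
    filtered = np.full_like(logits, filter_value)
    top_idx = np.argsort(logits, axis=-1)[:, -k:]
    rows = np.arange(logits.shape[0])[:, None]
    filtered[rows, top_idx] = logits[rows, top_idx]
    return filtered

ramp = np.arange(10, dtype=np.float32)[None, :]
print(top_k_filter(ramp, 3))  # only the three largest entries survive, matching the ramp assertions above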
@@ -140,13 +243,19 @@ class TFLogitsProcessorTest(unittest.TestCase):

         # instantiate all dist processors
         min_dist_proc = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
+        temp_dist_warp = TFTemperatureLogitsWarper(temperature=0.5)
         rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
+        top_k_warp = TFTopKLogitsWarper(3)
+        top_p_warp = TFTopPLogitsWarper(0.8)
         no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
         no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)

         # no processor list
         scores = min_dist_proc(input_ids, scores)
+        scores = temp_dist_warp(input_ids, scores)
         scores = rep_penalty_proc(input_ids, scores)
+        scores = top_k_warp(input_ids, scores)
+        scores = top_p_warp(input_ids, scores)
         scores = no_repeat_proc(input_ids, scores)
         scores = no_bad_words_dist_proc(input_ids, scores)
@@ -154,7 +263,10 @@ class TFLogitsProcessorTest(unittest.TestCase):
         processor = TFLogitsProcessorList(
             [
                 min_dist_proc,
+                temp_dist_warp,
                 rep_penalty_proc,
+                top_k_warp,
+                top_p_warp,
                 no_repeat_proc,
                 no_bad_words_dist_proc,
             ]
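During generation these objects are meant to be chained: TFLogitsProcessorList is designed to apply each processor and warper in order to the same (input_ids, scores) pair, mirroring the sequential calls in the previous hunk. A small usage sketch with made-up logits, assuming TensorFlow is installed and the classes added in this PR are importable from the top-level package (argument handling may differ slightly across versions):

import numpy as np
from transformers import (
    TFLogitsProcessorList,
    TFTemperatureLogitsWarper,
    TFTopKLogitsWarper,
    TFTopPLogitsWarper,
)

warpers = TFLogitsProcessorList(
    [
        TFTemperatureLogitsWarper(temperature=0.7),
        TFTopKLogitsWarper(50),
        TFTopPLogitsWarper(0.9),
    ]
)
scores = np.random.randn(2, 100).astype(np.float32)  # (batch_size, vocab_size), dummy logits
warped = warpers(None, scores)  # warpers ignore input_ids, so None is fine here, as in the tests
print(warped.shape)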
@@ -488,9 +488,12 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
             "top_k": 500,
             "top_p": 0.9,
         }
-        tf.random.set_seed(42)  # deterministic sampling sequence -> deterministic generation
-
-        output_ids = model.generate(input_ids, **generation_kwargs)
+        # forces the generation to happen on CPU, to avoid GPU-related quirks
+        with tf.device(":/CPU:0"):
+            tf.random.set_seed(42)  # deterministic sampling sequence -> deterministic generation
+            output_ids = model.generate(input_ids, **generation_kwargs)

         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

         expected_output_string = [
@@ -497,9 +497,11 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
             "top_k": 500,
             "top_p": 0.9,
         }
-        tf.random.set_seed(42)  # deterministic sampling sequence -> deterministic generation
-
-        output_ids = model.generate(input_ids, **generation_kwargs)
+        # forces the generation to happen on CPU, to avoid GPU-related quirks
+        with tf.device(":/CPU:0"):
+            tf.random.set_seed(42)  # deterministic sampling sequence -> deterministic generation
+            output_ids = model.generate(input_ids, **generation_kwargs)

         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
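The idea in both of these test hunks is to make sampling reproducible and hardware-independent: seed TF's global RNG and run the sampling op under an explicit CPU device scope, so GPU-specific kernels cannot change the drawn sequence. A standalone sketch of the same pattern (using the conventional "/CPU:0" device string; the tests spell it ":/CPU:0"):

import tensorflow as tf

with tf.device("/CPU:0"):
    tf.random.set_seed(42)  # same seed -> same sampled ids on every run of this snippet
    logits = tf.math.log([[0.2, 0.3, 0.5]])
    sampled = tf.random.categorical(logits, num_samples=5)
print(sampled.numpy())  # reproducible across runs given the seed and CPU placement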
@@ -947,7 +947,7 @@ class TFModelTesterMixin:

             if config.bos_token_id is None:
                 # if bos token id is not defined model needs input_ids
-                with self.assertRaises(AssertionError):
+                with self.assertRaises(ValueError):
                     model.generate(do_sample=True, max_length=5)
                 # num_return_sequences = 1
                 self._check_generated_ids(model.generate(input_ids, do_sample=True))