mirror of https://github.com/huggingface/transformers.git
synced 2025-07-24 23:08:57 +06:00

TF generate refactor - Sample (#15793)

* Add TF logits wrappers
* Add sample method
* add tests for TF logit wrappers
* TF generate sample tests now run on CPU

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

This commit is contained in:
parent 96ae92be8c
commit baab5e7cdf
@@ -154,18 +154,30 @@ generation.

[[autodoc]] TFLogitsProcessorList
    - __call__

[[autodoc]] TFLogitsWarper
    - __call__

[[autodoc]] TFTemperatureLogitsWarper
    - __call__

[[autodoc]] TFTopPLogitsWarper
    - __call__

[[autodoc]] TFTopKLogitsWarper
    - __call__

[[autodoc]] TFMinLengthLogitsProcessor
    - __call__

[[autodoc]] TFNoBadWordsLogitsProcessor
    - __call__

[[autodoc]] TFNoRepeatNGramLogitsProcessor
    - __call__

[[autodoc]] TFRepetitionPenaltyLogitsProcessor
    - __call__

[[autodoc]] FlaxLogitsProcessor
    - __call__
@@ -1656,10 +1656,14 @@ if is_tf_available():

    _import_structure["generation_tf_logits_process"] = [
        "TFLogitsProcessor",
        "TFLogitsProcessorList",
        "TFLogitsWarper",
        "TFMinLengthLogitsProcessor",
        "TFNoBadWordsLogitsProcessor",
        "TFNoRepeatNGramLogitsProcessor",
        "TFRepetitionPenaltyLogitsProcessor",
        "TFTemperatureLogitsWarper",
        "TFTopKLogitsWarper",
        "TFTopPLogitsWarper",
    ]
    _import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"]
    _import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
@@ -3706,10 +3710,14 @@ if TYPE_CHECKING:

    from .generation_tf_logits_process import (
        TFLogitsProcessor,
        TFLogitsProcessorList,
        TFLogitsWarper,
        TFMinLengthLogitsProcessor,
        TFNoBadWordsLogitsProcessor,
        TFNoRepeatNGramLogitsProcessor,
        TFRepetitionPenaltyLogitsProcessor,
        TFTemperatureLogitsWarper,
        TFTopKLogitsWarper,
        TFTopPLogitsWarper,
    )
    from .generation_tf_utils import tf_top_k_top_p_filtering
    from .keras_callbacks import KerasMetricCallback, PushToHubCallback
@@ -94,7 +94,7 @@ class FlaxLogitsProcessorList(list):

class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
    r"""
-   [`LogitsWarper`] for temperature (exponential scaling output probability distribution).
+   [`FlaxLogitsWarper`] for temperature (exponential scaling output probability distribution).

    Args:
        temperature (`float`):

@@ -114,7 +114,7 @@ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):

class FlaxTopPLogitsWarper(FlaxLogitsWarper):
    """
-   [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
+   [`FlaxLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.

    Args:
        top_p (`float`):

@@ -155,7 +155,7 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):

class FlaxTopKLogitsWarper(FlaxLogitsWarper):
    r"""
-   [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
+   [`FlaxLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.

    Args:
        top_k (`int`):

@@ -326,7 +326,7 @@ class FlaxGenerationMixin:
        raise NotImplementedError("`Beam sampling is currently not implemented.")

    def _get_logits_warper(
-       self, top_k: int = None, top_p: float = None, temperature: float = None
+       self, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
    ) -> FlaxLogitsProcessorList:
        """
        This method returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsWarper`]
@@ -58,6 +58,17 @@ class TFLogitsProcessor:
        )


class TFLogitsWarper:
    """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

    @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
        """TF method for warping logits."""
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )


class TFLogitsProcessorList(list):
    """
    This class can be used to create a list of [`TFLogitsProcessor`] to subsequently process a `scores` input tensor.
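The abstract `TFLogitsWarper` above only fixes a calling convention: take `(input_ids, scores)` and return new `scores`. As a quick illustration (not part of this commit), a custom warper could look like:

```python
import tensorflow as tf

class SubtractMaxLogitsWarper:  # hypothetical example, mirrors the TFLogitsWarper contract
    """Toy warper that shifts logits so the largest one is 0 (softmax-invariant)."""

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
        # same signature as TFLogitsWarper.__call__; `input_ids` may go unused
        return scores - tf.reduce_max(scores, axis=-1, keepdims=True)
```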
@@ -81,6 +92,109 @@ class TFLogitsProcessorList(list):
        return scores


class TFTemperatureLogitsWarper(TFLogitsWarper):
    r"""
    [`TFLogitsWarper`] for temperature (exponential scaling output probability distribution).

    Args:
        temperature (`float`):
            The value used to modulate the logits distribution.
    """

    def __init__(self, temperature: float):
        if not isinstance(temperature, float) or not (temperature > 0):
            raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")

        self.temperature = temperature

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
        scores = scores / self.temperature
        return scores

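Dividing the logits by `temperature` before the softmax sharpens the distribution for values below 1 and flattens it for values above 1. A small standalone check (made-up logits):

```python
import tensorflow as tf

scores = tf.constant([[1.0, 2.0, 4.0]])
sharp = tf.nn.softmax(scores / 0.5)   # temperature < 1: mass concentrates on the largest logit
smooth = tf.nn.softmax(scores / 2.0)  # temperature > 1: probabilities move toward uniform
print(sharp.numpy(), smooth.numpy())
```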
class TFTopKLogitsWarper(TFLogitsWarper):
    r"""
    [`TFLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.

    Args:
        top_k (`int`):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

        self.top_k = top_k
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
        top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1])  # Safety check
        # Boolean mask containing all tokens with a probability less than the last token of the top-k
        indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
        next_scores = tf.where(indices_to_remove, self.filter_value, scores)
        return next_scores

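Applied to a batch of logits, the warper keeps the `top_k` largest entries per row and sets everything else to `filter_value`. A minimal usage sketch (invented values; `input_ids` is unused by this warper, and the tests below likewise pass `None`):

```python
import tensorflow as tf

scores = tf.constant([[1.0, 5.0, 3.0, 2.0, 4.0]])
warper = TFTopKLogitsWarper(top_k=2)
filtered = warper(None, scores)
print(filtered.numpy())  # only 5.0 and 4.0 survive; the other three become -inf
```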
class TFTopPLogitsWarper(TFLogitsWarper):
    """
    [`TFLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to <= prob_cut_off.

    Args:
        top_p (`float`):
            If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept
            for generation.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
            raise ValueError(f"`top_p` has to be a float >= 0 and <= 1, but is {top_p}")

        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
        topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])

        mask_scores = tf.fill(scores.shape, self.filter_value)
        cumulative_probs = tf.math.cumsum(tf.nn.softmax(topk_scores, axis=-1), axis=-1)
        score_mask = cumulative_probs < self.top_p

        # Also include the token that is higher than top_p (the first false = shift and insert a True on the left)
        score_mask = tf.concat((tf.ones([score_mask.shape[0], 1], dtype=tf.bool), score_mask[:, :-1]), axis=-1)

        # Ensure min tokens to keep
        score_mask = tf.concat(
            (
                tf.ones([score_mask.shape[0], self.min_tokens_to_keep], dtype=tf.bool),
                score_mask[:, self.min_tokens_to_keep :],
            ),
            axis=-1,
        )

        # Mask the values that do not fit the criteria
        topk_next_scores = tf.where(score_mask, topk_scores, mask_scores)

        # Undo the topk sorting: converts the 2D matrix of per-row original indices of shape (batch_size, vocab_size)
        # to a 3D tensor of shape (batch_size, vocab_size, 2) containing the original score coordinate, from which we
        # can scatter (i.e. `scatter_indices[row, col, :]` is a tensor containing `[row, topk_indices[row, col]]`)
        scatter_rows = tf.tile(tf.expand_dims(tf.range(topk_indices.shape[0]), axis=-1), [1, topk_indices.shape[-1]])
        scatter_indices = tf.stack((scatter_rows, topk_indices), axis=-1)
        next_scores = tf.scatter_nd(scatter_indices, topk_next_scores, shape=topk_next_scores.shape)

        return next_scores

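The top-p path sorts each row, keeps the smallest prefix whose softmax mass reaches `top_p` (plus the first token past the cutoff), then scatters the kept scores back to their original positions. A check that mirrors the unit test added below (invented values):

```python
import numpy as np
import tensorflow as tf

# log-probabilities over a 4-token vocabulary; log is the inverse of the softmax used inside the warper
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5]], dtype=np.float32))
warper = TFTopPLogitsWarper(top_p=0.7)
filtered = tf.exp(warper(None, dist))
print(filtered.numpy())  # [[0.3 0.  0.  0.5]] -> 0.5 + 0.3 is the smallest prefix reaching 0.7
```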
class TFMinLengthLogitsProcessor(TFLogitsProcessor):
    r"""
    [`TFLogitsProcessor`] enforcing a min-length by setting EOS probability to 0.
@@ -28,6 +28,9 @@ from .generation_tf_logits_process import (
    TFNoBadWordsLogitsProcessor,
    TFNoRepeatNGramLogitsProcessor,
    TFRepetitionPenaltyLogitsProcessor,
    TFTemperatureLogitsWarper,
    TFTopKLogitsWarper,
    TFTopPLogitsWarper,
)
from .tf_utils import set_tensor_by_indices_to_value, shape_list
from .utils import logging
@@ -558,9 +561,7 @@ class TFGenerationMixin:
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        do_sample = do_sample if do_sample is not None else self.config.do_sample

-       is_greedy_gen_mode = num_beams == 1 and do_sample is False
-
-       if is_greedy_gen_mode:
+       if num_beams == 1:
            return self._generate(
                input_ids=input_ids,
                max_length=max_length,
@@ -790,304 +791,34 @@ class TFGenerationMixin:
            cur_len < max_length
        ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"

        if num_beams == 1:
            return self._generate_no_beam_search(
                input_ids,
                cur_len=cur_len,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                batch_size=effective_batch_size,
                vocab_size=vocab_size,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
                use_cache=use_cache,
                return_dict_in_generate=return_dict_in_generate,
                **model_kwargs,
            )
        else:
            return self._generate_beam_search(
                input_ids,
                cur_len=cur_len,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                early_stopping=early_stopping,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                batch_size=effective_batch_size,
                num_return_sequences=num_return_sequences,
                length_penalty=length_penalty,
                num_beams=num_beams,
                vocab_size=vocab_size,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
                use_cache=use_cache,
                forced_bos_token_id=forced_bos_token_id,
                forced_eos_token_id=forced_eos_token_id,
                return_dict_in_generate=return_dict_in_generate,
                **model_kwargs,
            )

    def _generate_no_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bad_words_ids,
        pad_token_id,
        eos_token_id,
        batch_size,
        vocab_size,
        encoder_outputs,
        attention_mask,
        use_cache,
        return_dict_in_generate,
        **kwargs
    ) -> Union[TFGreedySearchOutput, TFSampleOutput, tf.Tensor]:
        """
        Generate sequences for each example without beam search (num_beams == 1). All returned sequences are generated
        independently.
        """

        # length of generated sentences / unfinished sentences
        unfinished_sents = tf.ones_like(input_ids[:, 0])
        sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length

        # defined for encoder-decoder models, None for decoder-only models
        past = encoder_outputs

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and kwargs["output_scores"]) else None
        decoder_attentions = () if (return_dict_in_generate and kwargs["output_attentions"]) else None
        cross_attentions = () if (return_dict_in_generate and kwargs["output_attentions"]) else None
        decoder_hidden_states = () if (return_dict_in_generate and kwargs["output_hidden_states"]) else None
        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if self.config.is_encoder_decoder:
            encoder_attentions = (
                kwargs["encoder_attentions"] if (return_dict_in_generate and kwargs["encoder_attentions"]) else None
            )
            encoder_hidden_states = (
                kwargs["encoder_hidden_states"]
                if (return_dict_in_generate and kwargs["encoder_hidden_states"])
                else None
            )

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(
                input_ids,
                past=past,
                attention_mask=attention_mask,
                use_cache=use_cache,
                **kwargs,
            )
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=kwargs["output_attentions"],
                output_hidden_states=kwargs["output_hidden_states"],
            )
            next_token_logits = outputs.logits[:, -1, :]  # (batch_size * num_beams, vocab_size)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if kwargs["output_scores"]:
                    scores += (next_token_logits,)
                if kwargs["output_attentions"]:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if kwargs["output_hidden_states"]:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # if model has past, then set the past variable to speed up decoding
            if self._use_cache(outputs, use_cache):
                past = outputs[1]

            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                next_token_logits_penalties = _create_next_token_logits_penalties(
                    input_ids, next_token_logits, repetition_penalty
                )
                next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)

            if no_repeat_ngram_size > 0:
                # calculate a list of banned tokens to prevent repetitively generating the same ngrams
                # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
                banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
                # create banned_tokens boolean mask
                banned_tokens_indices_mask = []
                for banned_tokens_slice in banned_tokens:
                    banned_tokens_indices_mask.append(
                        [True if token in banned_tokens_slice else False for token in range(vocab_size)]
                    )

                next_token_logits = set_tensor_by_indices_to_value(
                    next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
                )

            if bad_words_ids is not None:
                # calculate a list of banned tokens according to bad words
                banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)

                banned_tokens_indices_mask = []
                for banned_tokens_slice in banned_tokens:
                    banned_tokens_indices_mask.append(
                        [True if token in banned_tokens_slice else False for token in range(vocab_size)]
                    )

                next_token_logits = set_tensor_by_indices_to_value(
                    next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
                )

            # set eos token prob to zero if min_length is not reached
            if eos_token_id is not None and cur_len < min_length:
                # create eos_token_id boolean mask
                is_token_logit_eos_token = tf.convert_to_tensor(
                    [True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
                )
                eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])

                next_token_logits = set_tensor_by_indices_to_value(
                    next_token_logits, eos_token_indices_mask, -float("inf")
                )

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature
                # Top-p/top-k filtering
                next_token_logits = tf_top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                # Sample
                next_token = tf.squeeze(
                    tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=1), axis=1
                )
            else:
                # Greedy decoding
                next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32)

            # update generations and finished sentences
            if eos_token_id is not None:
                # pad finished sentences if eos_token_id exist
                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
            else:
                tokens_to_add = next_token

            # add token and increase length by one
            input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)
            cur_len = cur_len + 1

            if eos_token_id is not None:
                eos_in_sents = tokens_to_add == eos_token_id
                # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
                is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
                    unfinished_sents, tf.cast(eos_in_sents, tf.int32)
                )
                sent_lengths = (
                    sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
                    + cur_len * is_sents_unfinished_and_token_to_add_is_eos
                )

                # unfinished_sents is set to zero if eos in sentence
                unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
            if tf.math.reduce_max(unfinished_sents) == 0:
                break

            # extend attention_mask for new generated input if only decoder
            if self.config.is_encoder_decoder is False:
                attention_mask = tf.concat(
                    [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
                )

        # if there are different sentences lengths in the batch, some batches have to be padded
        min_sent_length = tf.math.reduce_min(sent_lengths)
        max_sent_length = tf.math.reduce_max(sent_lengths)
        if min_sent_length != max_sent_length:
            assert pad_token_id is not None, "`pad_token_id` has to be defined if batches have different lengths"
            # finished sents are filled with pad_token
            padding = tf.ones([batch_size, max_sent_length.numpy()], dtype=tf.int32) * pad_token_id

            # create length masks for tf.where operation
            broad_casted_sent_lengths = tf.broadcast_to(
                tf.expand_dims(sent_lengths, -1), [batch_size, max_sent_length]
            )
            broad_casted_range = tf.transpose(
                tf.broadcast_to(tf.expand_dims(tf.range(max_sent_length), -1), [max_sent_length, batch_size])
            )

            decoded = tf.where(broad_casted_range < broad_casted_sent_lengths, input_ids, padding)
        else:
            decoded = input_ids

        if return_dict_in_generate:
            if do_sample:
                if self.config.is_encoder_decoder:
                    return TFSampleEncoderDecoderOutput(
                        sequences=decoded,
                        scores=scores,
                        encoder_attentions=encoder_attentions,
                        encoder_hidden_states=encoder_hidden_states,
                        decoder_attentions=decoder_attentions,
                        cross_attentions=cross_attentions,
                        decoder_hidden_states=decoder_hidden_states,
                    )
                else:
                    return TFSampleDecoderOnlyOutput(
                        sequences=decoded,
                        scores=scores,
                        attentions=decoder_attentions,
                        hidden_states=decoder_hidden_states,
                    )
            else:
                if self.config.is_encoder_decoder:
                    return TFGreedySearchEncoderDecoderOutput(
                        sequences=decoded,
                        scores=scores,
                        encoder_attentions=encoder_attentions,
                        encoder_hidden_states=encoder_hidden_states,
                        decoder_attentions=decoder_attentions,
                        cross_attentions=cross_attentions,
                        decoder_hidden_states=decoder_hidden_states,
                    )
                else:
                    return TFGreedySearchDecoderOnlyOutput(
                        sequences=decoded,
                        scores=scores,
                        attentions=decoder_attentions,
                        hidden_states=decoder_hidden_states,
                    )
        else:
            return decoded
        return self._generate_beam_search(
            input_ids,
            cur_len=cur_len,
            max_length=max_length,
            min_length=min_length,
            do_sample=do_sample,
            early_stopping=early_stopping,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            bad_words_ids=bad_words_ids,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            batch_size=effective_batch_size,
            num_return_sequences=num_return_sequences,
            length_penalty=length_penalty,
            num_beams=num_beams,
            vocab_size=vocab_size,
            encoder_outputs=encoder_outputs,
            attention_mask=attention_mask,
            use_cache=use_cache,
            forced_bos_token_id=forced_bos_token_id,
            forced_eos_token_id=forced_eos_token_id,
            return_dict_in_generate=return_dict_in_generate,
            **model_kwargs,
        )

    def _generate_beam_search(
        self,
@@ -1761,11 +1492,6 @@ class TFGenerationMixin:
            input_ids, return_dict_in_generate, model_kwargs
        )

-       # TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger
-       # refactor of all generation models in TF. `past` should be
-       # optional everywhere and not be set equal to encoder_outputs
-       model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None
-
        # 4. Prepare `input_ids` which will be used for auto-regressive generation
        if self.config.is_encoder_decoder:
            # if encoder-decoder then `input_ids` come from `decoder_start_token_id`
@@ -1787,6 +1513,7 @@ class TFGenerationMixin:
        # 5. determine generation mode
        # TODO(Matt, Joao, Patrick) - add more use cases here
        is_greedy_gen_mode = (num_beams == 1) and do_sample is False
+       is_sample_gen_mode = (num_beams == 1) and do_sample is True

        # 6. prepare distribution pre_processing samplers
        logits_processor = self._get_logits_processor(
@@ -1804,6 +1531,10 @@ class TFGenerationMixin:
                    f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
                )

+           # TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all
+           # generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs.
+           model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None
+
            # 8. run greedy search
            return self.greedy_search(
                input_ids,
@@ -1816,6 +1547,35 @@ class TFGenerationMixin:
                **model_kwargs,
            )

        elif is_sample_gen_mode:
            # 8. prepare logits warper
            logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)

            # 9. expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all
            # generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs.
            model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None

            # 10. run sample
            return self.sample(
                input_ids,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                max_length=max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                **model_kwargs,
            )

        # TODO(Matt, Joao, Patrick) - add more sub-generation methods here
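End users reach the new `sample` method through `generate`: `num_beams == 1` together with `do_sample=True` now routes into the branch above. A hedged usage sketch (the checkpoint name is just an example):

```python
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("Today is a beautiful day, and", return_tensors="tf").input_ids

# num_beams=1 (the default) + do_sample=True selects the sample branch added above
output_ids = model.generate(input_ids, do_sample=True, top_k=50, top_p=0.9, temperature=0.7, max_length=32)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```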
    def _prepare_attention_mask_for_generation(
@@ -1908,6 +1668,36 @@ class TFGenerationMixin:
                "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
            )

    @staticmethod
    def _expand_inputs_for_generation(
        input_ids: tf.Tensor,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        attention_mask: Optional[tf.Tensor] = None,
        encoder_outputs: Optional[ModelOutput] = None,
        **model_kwargs,
    ) -> Tuple[tf.Tensor, Dict[str, Any]]:
        expanded_return_idx = tf.reshape(
            tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1)
        )
        input_ids = tf.gather(input_ids, expanded_return_idx, axis=0)

        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = tf.gather(token_type_ids, expanded_return_idx, axis=0)

        if attention_mask is not None:
            model_kwargs["attention_mask"] = tf.gather(attention_mask, expanded_return_idx, axis=0)

        if is_encoder_decoder:
            if encoder_outputs is None:
                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
            encoder_outputs["last_hidden_state"] = tf.gather(
                encoder_outputs.last_hidden_state, expanded_return_idx, axis=0
            )
            model_kwargs["encoder_outputs"] = encoder_outputs
        return input_ids, model_kwargs

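The `tf.range`/`tf.tile`/`tf.gather` combination above simply repeats every batch row `expand_size` times, which is how `num_return_sequences > 1` is implemented for sampling. A tiny standalone check (invented tensor):

```python
import tensorflow as tf

input_ids = tf.constant([[1, 2], [3, 4]])
expand_size = 2
idx = tf.reshape(tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1,))
print(idx.numpy())                                # [0 0 1 1]
print(tf.gather(input_ids, idx, axis=0).numpy())  # [[1 2] [1 2] [3 4] [3 4]]
```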
    def _prepare_model_inputs(self, inputs: Optional[tf.Tensor] = None, bos_token_id: Optional[int] = None):
        # TODO(Patrick) - adapt this function when making `generate` more flexible
        # for all kinds of input types
@@ -1956,6 +1746,34 @@ class TFGenerationMixin:

        return model_kwargs

    def _get_logits_warper(
        self,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
    ) -> TFLogitsProcessorList:
        """
        This method returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsWarper`]
        instances used for multinomial sampling.
        """

        # init warp parameters
        top_k = top_k if top_k is not None else self.config.top_k
        top_p = top_p if top_p is not None else self.config.top_p
        temperature = temperature if temperature is not None else self.config.temperature
        # instantiate warpers list
        warpers = TFLogitsProcessorList()

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        if temperature is not None and temperature != 1.0:
            warpers.append(TFTemperatureLogitsWarper(temperature))
        if top_k is not None and top_k != 0:
            warpers.append(TFTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
        if top_p is not None and top_p < 1.0:
            warpers.append(TFTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
        return warpers

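Since `TFLogitsProcessorList` is a plain list whose `__call__` applies each member in order, the warpers assembled here can also be built and run by hand, e.g. to inspect what the defaults from `self.config` would do (a sketch using the classes from this PR; the logits are fake):

```python
import tensorflow as tf

warpers = TFLogitsProcessorList(
    [TFTemperatureLogitsWarper(0.7), TFTopKLogitsWarper(top_k=50), TFTopPLogitsWarper(top_p=0.9)]
)
logits = tf.random.normal((1, 100))  # fake next-token logits over a 100-token vocabulary
warped = warpers(None, logits)       # temperature, then top-k, then top-p, in list order
```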
    def _get_logits_processor(
        self,
        repetition_penalty: float,
@@ -2029,8 +1847,8 @@ class TFGenerationMixin:
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
            model_kwargs:
-               Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
-               If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
+               Additional model specific keyword arguments will be forwarded to the `call` function of the model. If
+               model is an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_tf_utils.TFGreedySearchDecoderOnlyOutput`],

@@ -2043,13 +1861,13 @@ class TFGenerationMixin:

        ```python
        >>> from transformers import (
-       ...     TFAutoTokenizer,
+       ...     AutoTokenizer,
        ...     TFAutoModelForCausalLM,
        ...     TFLogitsProcessorList,
        ...     TFMinLengthLogitsProcessor,
        ... )

-       >>> tokenizer = TFAutoTokenizer.from_pretrained("gpt2")
+       >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = TFAutoModelForCausalLM.from_pretrained("gpt2")

        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
@@ -2195,6 +2013,230 @@ class TFGenerationMixin:
        else:
            return input_ids

    def sample(
        self,
        input_ids: tf.Tensor,
        logits_processor: Optional[TFLogitsProcessorList] = None,
        logits_warper: Optional[TFLogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        **model_kwargs,
    ) -> Union[TFSampleOutput, tf.Tensor]:
r"""
|
||||
Generates sequences for models with a language modeling head using multinomial sampling.
|
||||
|
||||
Parameters:
|
||||
|
||||
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
|
||||
The sequence used as a prompt for the generation.
|
||||
logits_processor (`TFLogitsProcessorList`, *optional*):
|
||||
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`]
|
||||
used to modify the prediction scores of the language modeling head applied at each generation step.
|
||||
logits_warper (`TFLogitsProcessorList`, *optional*):
|
||||
An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsWarper`]
|
||||
used to warp the prediction score distribution of the language modeling head applied before multinomial
|
||||
sampling at each generation step.
|
||||
max_length (`int`, *optional*, defaults to 20):
|
||||
The maximum length of the sequence to be generated.
|
||||
pad_token_id (`int`, *optional*):
|
||||
The id of the *padding* token.
|
||||
eos_token_id (`int`, *optional*):
|
||||
The id of the *end-of-sequence* token.
|
||||
output_attentions (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more details.
|
||||
output_hidden_states (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||
for more details.
|
||||
output_scores (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
|
||||
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the `call` function of the model. If model is an
|
||||
encoder-decoder model the kwargs should include `encoder_outputs`.
|
||||
|
||||
Return:
|
||||
[`~generation_tf_utils.TFSampleDecoderOnlyOutput`], [`~generation_tf_utils.TFSampleEncoderDecoderOutput`]
|
||||
or `tf.Tensor`: A `tf.Tensor` containing the generated tokens (default behaviour) or a
|
||||
[`~generation_tf_utils.TFSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
||||
`return_dict_in_generate=True` or a [`~generation_tf_utils.TFSampleEncoderDecoderOutput`] if
|
||||
`model.config.is_encoder_decoder=True`.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import (
|
||||
... AutoTokenizer,
|
||||
... TFAutoModelForCausalLM,
|
||||
... TFLogitsProcessorList,
|
||||
... TFMinLengthLogitsProcessor,
|
||||
... TFTopKLogitsWarper,
|
||||
... TFTemperatureLogitsWarper,
|
||||
... )
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
>>> model = TFAutoModelForCausalLM.from_pretrained("gpt2")
|
||||
|
||||
>>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token
|
||||
>>> model.config.pad_token_id = model.config.eos_token_id
|
||||
|
||||
>>> input_prompt = "Today is a beautiful day, and"
|
||||
>>> input_ids = tokenizer(input_prompt, return_tensors="tf").input_ids
|
||||
|
||||
>>> # instantiate logits processors
|
||||
>>> logits_processor = TFLogitsProcessorList(
|
||||
... [
|
||||
... TFMinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
|
||||
... ]
|
||||
... )
|
||||
>>> # instantiate logits processors
|
||||
>>> logits_warper = TFLogitsProcessorList(
|
||||
... [
|
||||
... TFTopKLogitsWarper(50),
|
||||
... TFTemperatureLogitsWarper(0.7),
|
||||
... ]
|
||||
... )
|
||||
|
||||
>>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper)
|
||||
|
||||
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
||||
```"""
|
||||
|
||||
        # init values
        logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
        logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()

        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # TODO(Patrick): `encoder_outputs`, `past` hack. Currently T5, Bart expect `encoder_outputs`
        # to be wrapped into `past` variable. This is a bad design and needs to be updated.
        # Remove the following lines when updating all encoder-decoder models
        encoder_outputs = model_kwargs.pop("encoder_outputs", None)

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = encoder_outputs.get("attentions") if output_attentions else None
            encoder_hidden_states = encoder_outputs.get("hidden_states") if output_hidden_states else None

        # keep track of which sequences are already finished
        unfinished_sequences = tf.ones_like(input_ids[:, 0])
        cur_len = input_ids.shape[-1]

        while cur_len < max_length:
            # TODO(Patrick): remove following line by cleaning up `prepare_inputs_for_generation`
            # in all models
            model_kwargs["use_cache"] = None if "use_cache" not in model_kwargs else model_kwargs["use_cache"]

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # sample
            next_tokens = tf.squeeze(
                tf.random.categorical(logits=next_token_scores, num_samples=1, dtype=tf.int32), axis=1
            )

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = tf.concat([input_ids, next_tokens[:, None]], axis=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            cur_len = cur_len + 1

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id is not None:
                eos_in_sents = next_tokens == eos_token_id
                # if sentence is unfinished and the token to add is eos
                is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
                    unfinished_sequences, tf.cast(eos_in_sents, tf.int32)
                )

                # unfinished_sequences is set to zero if eos in sentence
                unfinished_sequences -= is_sents_unfinished_and_token_to_add_is_eos

            # stop when each sentence is finished, or if we exceed the maximum length
            if tf.math.reduce_max(unfinished_sequences) == 0:
                break

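The `next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)` line in the loop above is integer masking: finished rows keep receiving `pad_token_id`. A worked example with made-up values:

```python
import tensorflow as tf

next_tokens = tf.constant([11, 22])         # freshly sampled tokens for two sequences
unfinished_sequences = tf.constant([1, 0])  # the second sequence already produced EOS
pad_token_id = 0
tokens_to_add = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
print(tokens_to_add.numpy())  # [11  0] -> the finished sequence is padded from now on
```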
        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return TFSampleEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return TFSampleDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return input_ids


def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
    # create logit penalties for already seen input_ids
@@ -2292,7 +2334,6 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf")
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
        logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)

    if top_p < 1.0:
        sorted_indices = tf.argsort(logits, direction="DESCENDING")
        sorted_logits = tf.gather(
@@ -556,8 +556,8 @@ class GenerationMixin:
        input_ids: torch.LongTensor,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
-       attention_mask: torch.LongTensor = None,
-       encoder_outputs: ModelOutput = None,
+       attention_mask: Optional[torch.LongTensor] = None,
+       encoder_outputs: Optional[ModelOutput] = None,
        **model_kwargs,
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        expanded_return_idx = (

@@ -617,11 +617,11 @@ class GenerationMixin:

    def _get_logits_warper(
        self,
-       top_k: int = None,
-       top_p: float = None,
-       typical_p: float = None,
-       temperature: float = None,
-       num_beams: int = None,
+       top_k: Optional[int] = None,
+       top_p: Optional[float] = None,
+       typical_p: Optional[float] = None,
+       temperature: Optional[float] = None,
+       num_beams: Optional[int] = None,
    ) -> LogitsProcessorList:
        """
        This method returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
@@ -31,6 +31,13 @@ class TFLogitsProcessorList(metaclass=DummyObject):
        requires_backends(self, ["tf"])


class TFLogitsWarper(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFMinLengthLogitsProcessor(metaclass=DummyObject):
    _backends = ["tf"]


@@ -59,6 +66,27 @@ class TFRepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
        requires_backends(self, ["tf"])


class TFTemperatureLogitsWarper(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFTopKLogitsWarper(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFTopPLogitsWarper(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


def tf_top_k_top_p_filtering(*args, **kwargs):
    requires_backends(tf_top_k_top_p_filtering, ["tf"])
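These dummies keep `from transformers import TFTopKLogitsWarper` importable in an environment without TensorFlow and defer the failure to first use. A self-contained sketch of the pattern (not the exact `DummyObject`/`requires_backends` implementation):

```python
def requires_backends(obj, backends):
    # raise a helpful error naming the missing backend(s)
    name = getattr(obj, "__name__", obj.__class__.__name__)
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")

class TFTopKLogitsWarper:  # the real dummy also uses the DummyObject metaclass
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

try:
    TFTopKLogitsWarper(top_k=3)  # importing was fine; instantiating raises
except ImportError as e:
    print(e)  # TFTopKLogitsWarper requires the following backends: tf
```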
@@ -51,7 +51,7 @@ class LogitsProcessorTest(unittest.TestCase):
        scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length
        return scores

-   def test_min_lenght_dist_processor(self):
+   def test_min_length_dist_processor(self):
        vocab_size = 20
        batch_size = 4
        eos_token_id = 0
@@ -16,6 +16,8 @@

import unittest

import numpy as np

from transformers import is_tf_available
from transformers.testing_utils import require_tf

@@ -29,6 +31,9 @@ if is_tf_available():
        TFNoBadWordsLogitsProcessor,
        TFNoRepeatNGramLogitsProcessor,
        TFRepetitionPenaltyLogitsProcessor,
        TFTemperatureLogitsWarper,
        TFTopKLogitsWarper,
        TFTopPLogitsWarper,
    )
    from transformers.tf_utils import set_tensor_by_indices_to_value
@@ -38,7 +43,7 @@ if is_tf_available():

@require_tf
class TFLogitsProcessorTest(unittest.TestCase):
    def _get_uniform_logits(self, batch_size: int, length: int):
-       scores = tf.ones((batch_size, length), dtype=tf.float32) / length
+       scores = np.ones((batch_size, length), dtype=np.float32) / length
        return scores

    def test_min_length_dist_processor(self):

@@ -60,6 +65,37 @@ class TFLogitsProcessorTest(unittest.TestCase):
        scores_before_min_length = min_dist_processor(input_ids, scores)
        self.assertFalse(tf.math.reduce_any(tf.math.is_inf(scores_before_min_length)).numpy())

    def test_temperature_dist_warper(self):
        input_ids = None
        length = 20

        scores = self._get_uniform_logits(batch_size=2, length=length)

        # tweak scores to not be uniform anymore
        scores[1, 5] = (1 / length) + 0.1  # peak, 1st batch
        scores[1, 10] = (1 / length) - 0.4  # valley, 1st batch

        # compute softmax
        probs = tf.nn.softmax(scores, axis=-1)

        temp_dist_warper_sharper = TFTemperatureLogitsWarper(temperature=0.5)
        temp_dist_warper_smoother = TFTemperatureLogitsWarper(temperature=1.3)

        warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores)), axis=-1)
        warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores)), axis=-1)

        # uniform distribution stays uniform
        tf.debugging.assert_near(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)
        tf.debugging.assert_near(probs[0, :], warped_prob_smooth[0, :], atol=1e-3)

        # sharp peaks get higher, valleys get lower
        self.assertLess(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_sharp[1, :]))
        self.assertGreater(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_sharp[1, :]))

        # smooth peaks get lower, valleys get higher
        self.assertGreater(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_smooth[1, :]))
        self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :]))

    def test_repetition_penalty_dist_process(self):
        input_ids = tf.constant([[0, 1], [5, 0]], dtype=tf.int32)
        vocab_size = 10

@@ -82,6 +118,73 @@ class TFLogitsProcessorTest(unittest.TestCase):
        self.assertAlmostEqual(scores[1, 0].numpy(), (1 / vocab_size) / 2)
        self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2)

    def test_top_k_dist_warper(self):
        input_ids = None
        vocab_size = 10
        batch_size = 2

        # create ramp distribution
        ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy()
        ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size

        top_k_warp = TFTopKLogitsWarper(3)

        scores = top_k_warp(input_ids, ramp_logits)

        # check that correct tokens are filtered
        self.assertListEqual(tf.math.is_inf(scores[0]).numpy().tolist(), 7 * [True] + 3 * [False])
        self.assertListEqual(tf.math.is_inf(scores[1]).numpy().tolist(), 2 * [True] + 3 * [False] + 5 * [True])

        # check special cases
        length = 5

        logits = self._get_uniform_logits(batch_size=batch_size, length=length)
        top_k_warp_safety_check = TFTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)

        scores = top_k_warp_safety_check(input_ids, logits)
        # uniform dist is not changed
        self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [0, 0])

        ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy()
        scores = top_k_warp_safety_check(input_ids, ramp_logits)

        # min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
        self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [2, 2])

    def test_top_p_dist_warper(self):
        input_ids = None
        vocab_size = 10
        batch_size = 2

        # create distribution and take log (inverse to Softmax as taken in TFTopPLogitsWarper)
        dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], dtype=np.float32))

        top_p_warp = TFTopPLogitsWarper(0.7)
        filtered_dist = tf.exp(top_p_warp(input_ids, dist))

        # dist should be filtered to keep min num values so that sum is >= 0.7
        # exp (-inf) => 0
        EXPECTED_FILTERED_DIST = tf.constant([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], dtype=tf.float32)
        tf.debugging.assert_near(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)

        # check edge cases with negative and extreme logits
        ramp_logits = np.broadcast_to(
            np.arange(vocab_size, dtype=np.float32)[None, :], (batch_size, vocab_size)
        ).copy() - (vocab_size // 2)

        # make ramp_logits more extreme
        ramp_logits[1] = ramp_logits[1] * 100.0

        # make sure at least 2 tokens are kept
        top_p_warp = TFTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
        filtered_dist = top_p_warp(input_ids, ramp_logits)

        # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps
        # 2.
        self.assertListEqual(
            tf.math.reduce_sum(tf.where(filtered_dist != 0.0, 1, 0), axis=-1).numpy().tolist(), [3, 2]
        )

    def test_no_repeat_ngram_dist_processor(self):
        vocab_size = 3
        batch_size = 2

@@ -140,13 +243,19 @@ class TFLogitsProcessorTest(unittest.TestCase):

        # instantiate all dist processors
        min_dist_proc = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
        temp_dist_warp = TFTemperatureLogitsWarper(temperature=0.5)
        rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
        top_k_warp = TFTopKLogitsWarper(3)
        top_p_warp = TFTopPLogitsWarper(0.8)
        no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
        no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)

        # no processor list
        scores = min_dist_proc(input_ids, scores)
        scores = temp_dist_warp(input_ids, scores)
        scores = rep_penalty_proc(input_ids, scores)
        scores = top_k_warp(input_ids, scores)
        scores = top_p_warp(input_ids, scores)
        scores = no_repeat_proc(input_ids, scores)
        scores = no_bad_words_dist_proc(input_ids, scores)

@@ -154,7 +263,10 @@ class TFLogitsProcessorTest(unittest.TestCase):
        processor = TFLogitsProcessorList(
            [
                min_dist_proc,
                temp_dist_warp,
                rep_penalty_proc,
                top_k_warp,
                top_p_warp,
                no_repeat_proc,
                no_bad_words_dist_proc,
            ]
@@ -488,9 +488,12 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
            "top_k": 500,
            "top_p": 0.9,
        }
-       tf.random.set_seed(42)  # deterministic sampling sequence -> deterministic generation
-
-       output_ids = model.generate(input_ids, **generation_kwargs)
+       # forces the generation to happen on CPU, to avoid GPU-related quirks
+       with tf.device(":/CPU:0"):
+           tf.random.set_seed(42)  # deterministic sampling sequence -> deterministic generation
+           output_ids = model.generate(input_ids, **generation_kwargs)

        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        expected_output_string = [
@@ -497,9 +497,11 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
            "top_k": 500,
            "top_p": 0.9,
        }
-       tf.random.set_seed(42)  # deterministic sampling sequence -> deterministic generation
-
-       output_ids = model.generate(input_ids, **generation_kwargs)
+       # forces the generation to happen on CPU, to avoid GPU-related quirks
+       with tf.device(":/CPU:0"):
+           tf.random.set_seed(42)  # deterministic sampling sequence -> deterministic generation
+           output_ids = model.generate(input_ids, **generation_kwargs)

        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

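Both test files get the same treatment: the seeded sampling assertions are pinned to CPU because, per the comment in the diff, GPU execution has quirks that can change the drawn sequence for the same seed. The underlying pattern, as a standalone sketch:

```python
import tensorflow as tf

logits = tf.math.log([[0.1, 0.9]])
with tf.device("/CPU:0"):
    tf.random.set_seed(42)
    a = tf.random.categorical(logits, num_samples=5)
    tf.random.set_seed(42)
    b = tf.random.categorical(logits, num_samples=5)
print(bool(tf.reduce_all(a == b).numpy()))  # True: same seed, same device -> same draws
```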
@@ -947,7 +947,7 @@ class TFModelTesterMixin:

        if config.bos_token_id is None:
            # if bos token id is not defined model needs input_ids
-           with self.assertRaises(AssertionError):
+           with self.assertRaises(ValueError):
                model.generate(do_sample=True, max_length=5)
            # num_return_sequences = 1
            self._check_generated_ids(model.generate(input_ids, do_sample=True))