TF generate refactor - Sample (#15793)

* Add TF logits wrappers 

* Add sample method

* add tests for TF logit wrappers

* TF generate sample tests now run on CPU

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
Joao Gante 2022-03-02 16:13:54 +00:00 committed by GitHub
parent 96ae92be8c
commit baab5e7cdf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 652 additions and 332 deletions

View File

@ -154,18 +154,30 @@ generation.
[[autodoc]] TFLogitsProcessorList
- __call__
[[autodoc]] TFLogitsWarper
- __call__
[[autodoc]] TFTemperatureLogitsWarper
- __call__
[[autodoc]] TFTopPLogitsWarper
- __call__
[[autodoc]] TFTopKLogitsWarper
- __call__
[[autodoc]] TFMinLengthLogitsProcessor
- __call__
[[autodoc]] TFNoBadWordsLogitsProcessor
- __call__
[[autodoc]] TFNoRepeatNGramLogitsProcessor
- __call__
[[autodoc]] TFRepetitionPenaltyLogitsProcessor
- __call__
[[autodoc]] FlaxLogitsProcessor
- __call__

View File

@ -1656,10 +1656,14 @@ if is_tf_available():
_import_structure["generation_tf_logits_process"] = [
"TFLogitsProcessor",
"TFLogitsProcessorList",
"TFLogitsWarper",
"TFMinLengthLogitsProcessor",
"TFNoBadWordsLogitsProcessor",
"TFNoRepeatNGramLogitsProcessor",
"TFRepetitionPenaltyLogitsProcessor",
"TFTemperatureLogitsWarper",
"TFTopKLogitsWarper",
"TFTopPLogitsWarper",
]
_import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"]
_import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
@ -3706,10 +3710,14 @@ if TYPE_CHECKING:
from .generation_tf_logits_process import (
TFLogitsProcessor,
TFLogitsProcessorList,
TFLogitsWarper,
TFMinLengthLogitsProcessor,
TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor,
TFTemperatureLogitsWarper,
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
from .generation_tf_utils import tf_top_k_top_p_filtering
from .keras_callbacks import KerasMetricCallback, PushToHubCallback

View File

@ -94,7 +94,7 @@ class FlaxLogitsProcessorList(list):
class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
r"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
[`FlaxLogitsWarper`] for temperature (exponential scaling output probability distribution).
Args:
temperature (`float`):
@ -114,7 +114,7 @@ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
class FlaxTopPLogitsWarper(FlaxLogitsWarper):
"""
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
[`FlaxLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
Args:
top_p (`float`):
@ -155,7 +155,7 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
class FlaxTopKLogitsWarper(FlaxLogitsWarper):
r"""
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
[`FlaxLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
Args:
top_k (`int`):

View File

@ -326,7 +326,7 @@ class FlaxGenerationMixin:
raise NotImplementedError("`Beam sampling is currently not implemented.")
def _get_logits_warper(
self, top_k: int = None, top_p: float = None, temperature: float = None
self, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
) -> FlaxLogitsProcessorList:
"""
This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsWarper`]

View File

@ -58,6 +58,17 @@ class TFLogitsProcessor:
)
class TFLogitsWarper:
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
"""TF method for warping logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
class TFLogitsProcessorList(list):
"""
This class can be used to create a list of [`TFLogitsProcessor`] to subsequently process a `scores` input tensor.
@ -81,6 +92,109 @@ class TFLogitsProcessorList(list):
return scores
class TFTemperatureLogitsWarper(TFLogitsWarper):
r"""
[`TFLogitsWarper`] for temperature (exponential scaling output probability distribution).
Args:
temperature (`float`):
The value used to module the logits distribution.
"""
def __init__(self, temperature: float):
if not isinstance(temperature, float) or not (temperature > 0):
raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
self.temperature = temperature
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
scores = scores / self.temperature
return scores
class TFTopKLogitsWarper(TFLogitsWarper):
r"""
[`TFLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
Args:
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
self.top_k = top_k
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1]) # Safety check
# Boolean mask containing all tokens with a probability less than the last token of the top-k
indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
next_scores = tf.where(indices_to_remove, self.filter_value, scores)
return next_scores
class TFTopPLogitsWarper(TFLogitsWarper):
"""
[`TFLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to <= prob_cut_off.
Args:
top_p (`float`):
If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept
for generation.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
self.top_p = top_p
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])
mask_scores = tf.fill(scores.shape, self.filter_value)
cumulative_probs = tf.math.cumsum(tf.nn.softmax(topk_scores, axis=-1), axis=-1)
score_mask = cumulative_probs < self.top_p
# Also include the token that is higher than top_p (the first false = shift and insert a True on the left)
score_mask = tf.concat((tf.ones([score_mask.shape[0], 1], dtype=tf.bool), score_mask[:, :-1]), axis=-1)
# Ensure min tokens to keep
score_mask = tf.concat(
(
tf.ones([score_mask.shape[0], self.min_tokens_to_keep], dtype=tf.bool),
score_mask[:, self.min_tokens_to_keep :],
),
axis=-1,
)
# Mask the values that do not fit the criteria
topk_next_scores = tf.where(score_mask, topk_scores, mask_scores)
# Undo the topk sorting: converts the 2D matrix of per-row original indices of shape (batch_size, vocab_size)
# to a 3D tensor of shape (batch_size, vocab_size, 2) containing the original score coordinate, from which we
# can scatter (i.e. `scatter_indices[row, col, :]` is a tensor containing `[row, topk_indices[row, col]]`)
scatter_rows = tf.tile(tf.expand_dims(tf.range(topk_indices.shape[0]), axis=-1), [1, topk_indices.shape[-1]])
scatter_indices = tf.stack((scatter_rows, topk_indices), axis=-1)
next_scores = tf.scatter_nd(scatter_indices, topk_next_scores, shape=topk_next_scores.shape)
return next_scores
class TFMinLengthLogitsProcessor(TFLogitsProcessor):
r"""
[`TFLogitsProcessor`] enforcing a min-length by setting EOS probability to 0.

View File

@ -28,6 +28,9 @@ from .generation_tf_logits_process import (
TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor,
TFTemperatureLogitsWarper,
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
from .tf_utils import set_tensor_by_indices_to_value, shape_list
from .utils import logging
@ -558,9 +561,7 @@ class TFGenerationMixin:
num_beams = num_beams if num_beams is not None else self.config.num_beams
do_sample = do_sample if do_sample is not None else self.config.do_sample
is_greedy_gen_mode = num_beams == 1 and do_sample is False
if is_greedy_gen_mode:
if num_beams == 1:
return self._generate(
input_ids=input_ids,
max_length=max_length,
@ -790,304 +791,34 @@ class TFGenerationMixin:
cur_len < max_length
), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
if num_beams == 1:
return self._generate_no_beam_search(
input_ids,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
batch_size=effective_batch_size,
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
return_dict_in_generate=return_dict_in_generate,
**model_kwargs,
)
else:
return self._generate_beam_search(
input_ids,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
early_stopping=early_stopping,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
batch_size=effective_batch_size,
num_return_sequences=num_return_sequences,
length_penalty=length_penalty,
num_beams=num_beams,
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
return_dict_in_generate=return_dict_in_generate,
**model_kwargs,
)
def _generate_no_beam_search(
self,
input_ids,
cur_len,
max_length,
min_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
no_repeat_ngram_size,
bad_words_ids,
pad_token_id,
eos_token_id,
batch_size,
vocab_size,
encoder_outputs,
attention_mask,
use_cache,
return_dict_in_generate,
**kwargs
) -> Union[TFGreedySearchOutput, TFSampleOutput, tf.Tensor]:
"""
Generate sequences for each example without beam search (num_beams == 1). All returned sequences are generated
independently.
"""
# length of generated sentences / unfinished sentences
unfinished_sents = tf.ones_like(input_ids[:, 0])
sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length
# defined for encoder-decoder models, None for decoder-only models
past = encoder_outputs
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and kwargs["output_scores"]) else None
decoder_attentions = () if (return_dict_in_generate and kwargs["output_attentions"]) else None
cross_attentions = () if (return_dict_in_generate and kwargs["output_attentions"]) else None
decoder_hidden_states = () if (return_dict_in_generate and kwargs["output_hidden_states"]) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if self.config.is_encoder_decoder:
encoder_attentions = (
kwargs["encoder_attentions"] if (return_dict_in_generate and kwargs["encoder_attentions"]) else None
)
encoder_hidden_states = (
kwargs["encoder_hidden_states"]
if (return_dict_in_generate and kwargs["encoder_hidden_states"])
else None
)
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids,
past=past,
attention_mask=attention_mask,
use_cache=use_cache,
**kwargs,
)
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=kwargs["output_attentions"],
output_hidden_states=kwargs["output_hidden_states"],
)
next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if kwargs["output_scores"]:
scores += (next_token_logits,)
if kwargs["output_attentions"]:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if kwargs["output_hidden_states"]:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
# if model has past, then set the past variable to speed up decoding
if self._use_cache(outputs, use_cache):
past = outputs[1]
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
next_token_logits_penalties = _create_next_token_logits_penalties(
input_ids, next_token_logits, repetition_penalty
)
next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
# create banned_tokens boolean mask
banned_tokens_indices_mask = []
for banned_tokens_slice in banned_tokens:
banned_tokens_indices_mask.append(
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
)
next_token_logits = set_tensor_by_indices_to_value(
next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
)
if bad_words_ids is not None:
# calculate a list of banned tokens according to bad words
banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
banned_tokens_indices_mask = []
for banned_tokens_slice in banned_tokens:
banned_tokens_indices_mask.append(
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
)
next_token_logits = set_tensor_by_indices_to_value(
next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
)
# set eos token prob to zero if min_length is not reached
if eos_token_id is not None and cur_len < min_length:
# create eos_token_id boolean mask
is_token_logit_eos_token = tf.convert_to_tensor(
[True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
)
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
next_token_logits = set_tensor_by_indices_to_value(
next_token_logits, eos_token_indices_mask, -float("inf")
)
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
next_token_logits = next_token_logits / temperature
# Top-p/top-k filtering
next_token_logits = tf_top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
# Sample
next_token = tf.squeeze(
tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=1), axis=1
)
else:
# Greedy decoding
next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32)
# update generations and finished sentences
if eos_token_id is not None:
# pad finished sentences if eos_token_id exist
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
else:
tokens_to_add = next_token
# add token and increase length by one
input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)
cur_len = cur_len + 1
if eos_token_id is not None:
eos_in_sents = tokens_to_add == eos_token_id
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
unfinished_sents, tf.cast(eos_in_sents, tf.int32)
)
sent_lengths = (
sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
+ cur_len * is_sents_unfinished_and_token_to_add_is_eos
)
# unfinished_sents is set to zero if eos in sentence
unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos
# stop when there is a </s> in each sentence, or if we exceed the maximum length
if tf.math.reduce_max(unfinished_sents) == 0:
break
# extend attention_mask for new generated input if only decoder
if self.config.is_encoder_decoder is False:
attention_mask = tf.concat(
[attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
)
# if there are different sentences lengths in the batch, some batches have to be padded
min_sent_length = tf.math.reduce_min(sent_lengths)
max_sent_length = tf.math.reduce_max(sent_lengths)
if min_sent_length != max_sent_length:
assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
# finished sents are filled with pad_token
padding = tf.ones([batch_size, max_sent_length.numpy()], dtype=tf.int32) * pad_token_id
# create length masks for tf.where operation
broad_casted_sent_lengths = tf.broadcast_to(
tf.expand_dims(sent_lengths, -1), [batch_size, max_sent_length]
)
broad_casted_range = tf.transpose(
tf.broadcast_to(tf.expand_dims(tf.range(max_sent_length), -1), [max_sent_length, batch_size])
)
decoded = tf.where(broad_casted_range < broad_casted_sent_lengths, input_ids, padding)
else:
decoded = input_ids
if return_dict_in_generate:
if do_sample:
if self.config.is_encoder_decoder:
return TFSampleEncoderDecoderOutput(
sequences=decoded,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
return TFSampleDecoderOnlyOutput(
sequences=decoded,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
if self.config.is_encoder_decoder:
return TFGreedySearchEncoderDecoderOutput(
sequences=decoded,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
return TFGreedySearchDecoderOnlyOutput(
sequences=decoded,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
return decoded
return self._generate_beam_search(
input_ids,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
early_stopping=early_stopping,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
batch_size=effective_batch_size,
num_return_sequences=num_return_sequences,
length_penalty=length_penalty,
num_beams=num_beams,
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
return_dict_in_generate=return_dict_in_generate,
**model_kwargs,
)
def _generate_beam_search(
self,
@ -1761,11 +1492,6 @@ class TFGenerationMixin:
input_ids, return_dict_in_generate, model_kwargs
)
# TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger
# refactor of all generation models in TF. `past` should be
# optional everywhere and not be set equal to encoder_outputs
model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None
# 4. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
# if encoder-decoder then `input_ids` come from `decoder_start_token_id`
@ -1787,6 +1513,7 @@ class TFGenerationMixin:
# 5. determine generation mode
# TODO(Matt, Joao, Patrick) - add more use cases here
is_greedy_gen_mode = (num_beams == 1) and do_sample is False
is_sample_gen_mode = (num_beams == 1) and do_sample is True
# 6. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
@ -1804,6 +1531,10 @@ class TFGenerationMixin:
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
)
# TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all
# generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs.
model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None
# 8. run greedy search
return self.greedy_search(
input_ids,
@ -1816,6 +1547,35 @@ class TFGenerationMixin:
**model_kwargs,
)
elif is_sample_gen_mode:
# 8. prepare logits warper
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
# 9. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids,
expand_size=num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all
# generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs.
model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None
# 10. run sample
return self.sample(
input_ids,
logits_processor=logits_processor,
logits_warper=logits_warper,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
**model_kwargs,
)
# TODO(Matt, Joao, Patrick) - add more sub-generation methods here
def _prepare_attention_mask_for_generation(
@ -1908,6 +1668,36 @@ class TFGenerationMixin:
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
@staticmethod
def _expand_inputs_for_generation(
input_ids: tf.Tensor,
expand_size: int = 1,
is_encoder_decoder: bool = False,
attention_mask: Optional[tf.Tensor] = None,
encoder_outputs: Optional[ModelOutput] = None,
**model_kwargs,
) -> Tuple[tf.Tensor, Dict[str, Any]]:
expanded_return_idx = tf.reshape(
tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1)
)
input_ids = tf.gather(input_ids, expanded_return_idx, axis=0)
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = tf.gather(token_type_ids, expanded_return_idx, axis=0)
if attention_mask is not None:
model_kwargs["attention_mask"] = tf.gather(attention_mask, expanded_return_idx, axis=0)
if is_encoder_decoder:
if encoder_outputs is None:
raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
encoder_outputs["last_hidden_state"] = tf.gather(
encoder_outputs.last_hidden_state, expanded_return_idx, axis=0
)
model_kwargs["encoder_outputs"] = encoder_outputs
return input_ids, model_kwargs
def _prepare_model_inputs(self, inputs: Optional[tf.Tensor] = None, bos_token_id: Optional[int] = None):
# TODO(Patrick) - adapt this function when making `generate` more flexible
# for all kinds of input types
@ -1956,6 +1746,34 @@ class TFGenerationMixin:
return model_kwargs
def _get_logits_warper(
self,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
) -> TFLogitsProcessorList:
"""
This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsWarper`]
instances used for multinomial sampling.
"""
# init warp parameters
top_k = top_k if top_k is not None else self.config.top_k
top_p = top_p if top_p is not None else self.config.top_p
temperature = temperature if temperature is not None else self.config.temperature
# instantiate warpers list
warpers = TFLogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
if temperature is not None and temperature != 1.0:
warpers.append(TFTemperatureLogitsWarper(temperature))
if top_k is not None and top_k != 0:
warpers.append(TFTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
if top_p is not None and top_p < 1.0:
warpers.append(TFTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
return warpers
def _get_logits_processor(
self,
repetition_penalty: float,
@ -2029,8 +1847,8 @@ class TFGenerationMixin:
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
Additional model specific keyword arguments will be forwarded to the `call` function of the model. If
model is an encoder-decoder model the kwargs should include `encoder_outputs`.
Return:
[`~generation_tf_utils.TFGreedySearchDecoderOnlyOutput`],
@ -2043,13 +1861,13 @@ class TFGenerationMixin:
```python
>>> from transformers import (
... TFAutoTokenizer,
... AutoTokenizer,
... TFAutoModelForCausalLM,
... TFLogitsProcessorList,
... TFMinLengthLogitsProcessor,
... )
>>> tokenizer = TFAutoTokenizer.from_pretrained("gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = TFAutoModelForCausalLM.from_pretrained("gpt2")
>>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token
@ -2195,6 +2013,230 @@ class TFGenerationMixin:
else:
return input_ids
def sample(
self,
input_ids: tf.Tensor,
logits_processor: Optional[TFLogitsProcessorList] = None,
logits_warper: Optional[TFLogitsProcessorList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
**model_kwargs,
) -> Union[TFSampleOutput, tf.Tensor]:
r"""
Generates sequences for models with a language modeling head using multinomial sampling.
Parameters:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
logits_processor (`TFLogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
logits_warper (`TFLogitsProcessorList`, *optional*):
An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsWarper`]
used to warp the prediction score distribution of the language modeling head applied before multinomial
sampling at each generation step.
max_length (`int`, *optional*, defaults to 20):
The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
model_kwargs:
Additional model specific kwargs will be forwarded to the `call` function of the model. If model is an
encoder-decoder model the kwargs should include `encoder_outputs`.
Return:
[`~generation_tf_utils.TFSampleDecoderOnlyOutput`], [`~generation_tf_utils.TFSampleEncoderDecoderOutput`]
or `tf.Tensor`: A `tf.Tensor` containing the generated tokens (default behaviour) or a
[`~generation_tf_utils.TFSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation_tf_utils.TFSampleEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
Examples:
```python
>>> from transformers import (
... AutoTokenizer,
... TFAutoModelForCausalLM,
... TFLogitsProcessorList,
... TFMinLengthLogitsProcessor,
... TFTopKLogitsWarper,
... TFTemperatureLogitsWarper,
... )
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = TFAutoModelForCausalLM.from_pretrained("gpt2")
>>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token
>>> model.config.pad_token_id = model.config.eos_token_id
>>> input_prompt = "Today is a beautiful day, and"
>>> input_ids = tokenizer(input_prompt, return_tensors="tf").input_ids
>>> # instantiate logits processors
>>> logits_processor = TFLogitsProcessorList(
... [
... TFMinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
... ]
... )
>>> # instantiate logits processors
>>> logits_warper = TFLogitsProcessorList(
... [
... TFTopKLogitsWarper(50),
... TFTemperatureLogitsWarper(0.7),
... ]
... )
>>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
```"""
# init values
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# TODO(Patrick): `encoder_outputs`, `past` hack. Currently T5, Bart expect `encoder_outputs`
# to be wrapped into `past` variable. This is a bad design and needs to be updated.
# Remove the following lines when updating all encoder-decoder models
encoder_outputs = model_kwargs.pop("encoder_outputs", None)
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = encoder_outputs.get("attentions") if output_attentions else None
encoder_hidden_states = encoder_outputs.get("hidden_states") if output_hidden_states else None
# keep track of which sequences are already finished
unfinished_sequences = tf.ones_like(input_ids[:, 0])
cur_len = input_ids.shape[-1]
while cur_len < max_length:
# TODO(Patrick): remove following line by cleaning up `prepare_inputs_for_generation`
# in all models
model_kwargs["use_cache"] = None if "use_cache" not in model_kwargs else model_kwargs["use_cache"]
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
# sample
next_tokens = tf.squeeze(
tf.random.categorical(logits=next_token_scores, num_samples=1, dtype=tf.int32), axis=1
)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
input_ids = tf.concat([input_ids, next_tokens[:, None]], axis=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
cur_len = cur_len + 1
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
eos_in_sents = next_tokens == eos_token_id
# if sentence is unfinished and the token to add is eos
is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
unfinished_sequences, tf.cast(eos_in_sents, tf.int32)
)
# unfinished_sequences is set to zero if eos in sentence
unfinished_sequences -= is_sents_unfinished_and_token_to_add_is_eos
# stop when each sentence is finished, or if we exceed the maximum length
if tf.math.reduce_max(unfinished_sequences) == 0:
break
if return_dict_in_generate:
if self.config.is_encoder_decoder:
return TFSampleEncoderDecoderOutput(
sequences=input_ids,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
return TFSampleDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
return input_ids
def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
# create logit penalties for already seen input_ids
@ -2292,7 +2334,6 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
if top_p < 1.0:
sorted_indices = tf.argsort(logits, direction="DESCENDING")
sorted_logits = tf.gather(

View File

@ -556,8 +556,8 @@ class GenerationMixin:
input_ids: torch.LongTensor,
expand_size: int = 1,
is_encoder_decoder: bool = False,
attention_mask: torch.LongTensor = None,
encoder_outputs: ModelOutput = None,
attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[ModelOutput] = None,
**model_kwargs,
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
expanded_return_idx = (
@ -617,11 +617,11 @@ class GenerationMixin:
def _get_logits_warper(
self,
top_k: int = None,
top_p: float = None,
typical_p: float = None,
temperature: float = None,
num_beams: int = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
typical_p: Optional[float] = None,
temperature: Optional[float] = None,
num_beams: Optional[int] = None,
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances

View File

@ -31,6 +31,13 @@ class TFLogitsProcessorList(metaclass=DummyObject):
requires_backends(self, ["tf"])
class TFLogitsWarper(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFMinLengthLogitsProcessor(metaclass=DummyObject):
_backends = ["tf"]
@ -59,6 +66,27 @@ class TFRepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["tf"])
class TFTemperatureLogitsWarper(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFTopKLogitsWarper(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFTopPLogitsWarper(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
def tf_top_k_top_p_filtering(*args, **kwargs):
requires_backends(tf_top_k_top_p_filtering, ["tf"])

View File

@ -51,7 +51,7 @@ class LogitsProcessorTest(unittest.TestCase):
scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length
return scores
def test_min_lenght_dist_processor(self):
def test_min_length_dist_processor(self):
vocab_size = 20
batch_size = 4
eos_token_id = 0

View File

@ -16,6 +16,8 @@
import unittest
import numpy as np
from transformers import is_tf_available
from transformers.testing_utils import require_tf
@ -29,6 +31,9 @@ if is_tf_available():
TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor,
TFTemperatureLogitsWarper,
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
from transformers.tf_utils import set_tensor_by_indices_to_value
@ -38,7 +43,7 @@ if is_tf_available():
@require_tf
class TFLogitsProcessorTest(unittest.TestCase):
def _get_uniform_logits(self, batch_size: int, length: int):
scores = tf.ones((batch_size, length), dtype=tf.float32) / length
scores = np.ones((batch_size, length), dtype=np.float32) / length
return scores
def test_min_length_dist_processor(self):
@ -60,6 +65,37 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores_before_min_length = min_dist_processor(input_ids, scores)
self.assertFalse(tf.math.reduce_any(tf.math.is_inf(scores_before_min_length)).numpy())
def test_temperature_dist_warper(self):
input_ids = None
length = 20
scores = self._get_uniform_logits(batch_size=2, length=length)
# tweak scores to not be uniform anymore
scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch
scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch
# compute softmax
probs = tf.nn.softmax(scores, axis=-1)
temp_dist_warper_sharper = TFTemperatureLogitsWarper(temperature=0.5)
temp_dist_warper_smoother = TFTemperatureLogitsWarper(temperature=1.3)
warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores)), axis=-1)
warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores)), axis=-1)
# uniform distribution stays uniform
tf.debugging.assert_near(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)
tf.debugging.assert_near(probs[0, :], warped_prob_smooth[0, :], atol=1e-3)
# sharp peaks get higher, valleys get lower
self.assertLess(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_sharp[1, :]))
self.assertGreater(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_sharp[1, :]))
# smooth peaks get lower, valleys get higher
self.assertGreater(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_smooth[1, :]))
self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :]))
def test_repetition_penalty_dist_process(self):
input_ids = tf.constant([[0, 1], [5, 0]], dtype=tf.int32)
vocab_size = 10
@ -82,6 +118,73 @@ class TFLogitsProcessorTest(unittest.TestCase):
self.assertAlmostEqual(scores[1, 0].numpy(), (1 / vocab_size) / 2)
self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2)
def test_top_k_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create ramp distribution
ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy()
ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size
top_k_warp = TFTopKLogitsWarper(3)
scores = top_k_warp(input_ids, ramp_logits)
# check that correct tokens are filtered
self.assertListEqual(tf.math.is_inf(scores[0]).numpy().tolist(), 7 * [True] + 3 * [False])
self.assertListEqual(tf.math.is_inf(scores[1]).numpy().tolist(), 2 * [True] + 3 * [False] + 5 * [True])
# check special cases
length = 5
logits = self._get_uniform_logits(batch_size=batch_size, length=length)
top_k_warp_safety_check = TFTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
scores = top_k_warp_safety_check(input_ids, logits)
# uniform dist is not changed
self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [0, 0])
ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy()
scores = top_k_warp_safety_check(input_ids, ramp_logits)
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [2, 2])
def test_top_p_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create distribution and take log (inverse to Softmax as taken in TFTopPLogitsWarper)
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], dtype=np.float32))
top_p_warp = TFTopPLogitsWarper(0.7)
filtered_dist = tf.exp(top_p_warp(input_ids, dist))
# dist should be filtered to keep min num values so that sum is >= 0.7
# exp (-inf) => 0
EXPECTED_FILTERED_DIST = tf.constant([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], dtype=tf.float32)
tf.debugging.assert_near(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)
# check edge cases with negative and extreme logits
ramp_logits = np.broadcast_to(
np.arange(vocab_size, dtype=np.float32)[None, :], (batch_size, vocab_size)
).copy() - (vocab_size // 2)
# make ramp_logits more extreme
ramp_logits[1] = ramp_logits[1] * 100.0
# make sure at least 2 tokens are kept
top_p_warp = TFTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
filtered_dist = top_p_warp(input_ids, ramp_logits)
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps
# 2.
self.assertListEqual(
tf.math.reduce_sum(tf.where(filtered_dist != 0.0, 1, 0), axis=-1).numpy().tolist(), [3, 2]
)
def test_no_repeat_ngram_dist_processor(self):
vocab_size = 3
batch_size = 2
@ -140,13 +243,19 @@ class TFLogitsProcessorTest(unittest.TestCase):
# instantiate all dist processors
min_dist_proc = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
temp_dist_warp = TFTemperatureLogitsWarper(temperature=0.5)
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
top_k_warp = TFTopKLogitsWarper(3)
top_p_warp = TFTopPLogitsWarper(0.8)
no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)
# no processor list
scores = min_dist_proc(input_ids, scores)
scores = temp_dist_warp(input_ids, scores)
scores = rep_penalty_proc(input_ids, scores)
scores = top_k_warp(input_ids, scores)
scores = top_p_warp(input_ids, scores)
scores = no_repeat_proc(input_ids, scores)
scores = no_bad_words_dist_proc(input_ids, scores)
@ -154,7 +263,10 @@ class TFLogitsProcessorTest(unittest.TestCase):
processor = TFLogitsProcessorList(
[
min_dist_proc,
temp_dist_warp,
rep_penalty_proc,
top_k_warp,
top_p_warp,
no_repeat_proc,
no_bad_words_dist_proc,
]

View File

@ -488,9 +488,12 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
"top_k": 500,
"top_p": 0.9,
}
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, **generation_kwargs)
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = [

View File

@ -497,9 +497,11 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
"top_k": 500,
"top_p": 0.9,
}
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, **generation_kwargs)
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

View File

@ -947,7 +947,7 @@ class TFModelTesterMixin:
if config.bos_token_id is None:
# if bos token id is not defined model needs input_ids
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
model.generate(do_sample=True, max_length=5)
# num_return_sequences = 1
self._check_generated_ids(model.generate(input_ids, do_sample=True))