diff --git a/docs/source/internal/generation_utils.mdx b/docs/source/internal/generation_utils.mdx index 5ee321c0a44..089dcf3b9c2 100644 --- a/docs/source/internal/generation_utils.mdx +++ b/docs/source/internal/generation_utils.mdx @@ -154,18 +154,30 @@ generation. [[autodoc]] TFLogitsProcessorList - __call__ +[[autodoc]] TFLogitsWarper + - __call__ + +[[autodoc]] TFTemperatureLogitsWarper + - __call__ + +[[autodoc]] TFTopPLogitsWarper + - __call__ + +[[autodoc]] TFTopKLogitsWarper + - __call__ + [[autodoc]] TFMinLengthLogitsProcessor - __call__ [[autodoc]] TFNoBadWordsLogitsProcessor - __call__ - + [[autodoc]] TFNoRepeatNGramLogitsProcessor - __call__ - + [[autodoc]] TFRepetitionPenaltyLogitsProcessor - __call__ - + [[autodoc]] FlaxLogitsProcessor - __call__ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 29bb7243df6..ccb99d0a76f 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1656,10 +1656,14 @@ if is_tf_available(): _import_structure["generation_tf_logits_process"] = [ "TFLogitsProcessor", "TFLogitsProcessorList", + "TFLogitsWarper", "TFMinLengthLogitsProcessor", "TFNoBadWordsLogitsProcessor", "TFNoRepeatNGramLogitsProcessor", "TFRepetitionPenaltyLogitsProcessor", + "TFTemperatureLogitsWarper", + "TFTopKLogitsWarper", + "TFTopPLogitsWarper", ] _import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"] _import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"] @@ -3706,10 +3710,14 @@ if TYPE_CHECKING: from .generation_tf_logits_process import ( TFLogitsProcessor, TFLogitsProcessorList, + TFLogitsWarper, TFMinLengthLogitsProcessor, TFNoBadWordsLogitsProcessor, TFNoRepeatNGramLogitsProcessor, TFRepetitionPenaltyLogitsProcessor, + TFTemperatureLogitsWarper, + TFTopKLogitsWarper, + TFTopPLogitsWarper, ) from .generation_tf_utils import tf_top_k_top_p_filtering from .keras_callbacks import KerasMetricCallback, PushToHubCallback diff --git a/src/transformers/generation_flax_logits_process.py b/src/transformers/generation_flax_logits_process.py index 76a09ed012d..e4ebc343e63 100644 --- a/src/transformers/generation_flax_logits_process.py +++ b/src/transformers/generation_flax_logits_process.py @@ -94,7 +94,7 @@ class FlaxLogitsProcessorList(list): class FlaxTemperatureLogitsWarper(FlaxLogitsWarper): r""" - [`LogitsWarper`] for temperature (exponential scaling output probability distribution). + [`FlaxLogitsWarper`] for temperature (exponential scaling output probability distribution). Args: temperature (`float`): @@ -114,7 +114,7 @@ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper): class FlaxTopPLogitsWarper(FlaxLogitsWarper): """ - [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. + [`FlaxLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. Args: top_p (`float`): @@ -155,7 +155,7 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper): class FlaxTopKLogitsWarper(FlaxLogitsWarper): r""" - [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. + [`FlaxLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. 
Args: top_k (`int`): diff --git a/src/transformers/generation_flax_utils.py b/src/transformers/generation_flax_utils.py index 37015f76107..a9f76d738e9 100644 --- a/src/transformers/generation_flax_utils.py +++ b/src/transformers/generation_flax_utils.py @@ -326,7 +326,7 @@ class FlaxGenerationMixin: raise NotImplementedError("`Beam sampling is currently not implemented.") def _get_logits_warper( - self, top_k: int = None, top_p: float = None, temperature: float = None + self, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None ) -> FlaxLogitsProcessorList: """ This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsWarper`] diff --git a/src/transformers/generation_tf_logits_process.py b/src/transformers/generation_tf_logits_process.py index 74a61768566..098d76ef27a 100644 --- a/src/transformers/generation_tf_logits_process.py +++ b/src/transformers/generation_tf_logits_process.py @@ -58,6 +58,17 @@ class TFLogitsProcessor: ) +class TFLogitsWarper: + """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" + + @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor: + """TF method for warping logits.""" + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." + ) + + class TFLogitsProcessorList(list): """ This class can be used to create a list of [`TFLogitsProcessor`] to subsequently process a `scores` input tensor. @@ -81,6 +92,109 @@ class TFLogitsProcessorList(list): return scores +class TFTemperatureLogitsWarper(TFLogitsWarper): + r""" + [`TFLogitsWarper`] for temperature (exponential scaling output probability distribution). + + Args: + temperature (`float`): + The value used to module the logits distribution. + """ + + def __init__(self, temperature: float): + if not isinstance(temperature, float) or not (temperature > 0): + raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}") + + self.temperature = temperature + + def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor: + scores = scores / self.temperature + return scores + + +class TFTopKLogitsWarper(TFLogitsWarper): + r""" + [`TFLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. + + Args: + top_k (`int`): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + filter_value (`float`, *optional*, defaults to `-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. 
+ """ + + def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}") + + self.top_k = top_k + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor: + top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1]) # Safety check + # Boolean mask containing all tokens with a probability less than the last token of the top-k + indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:] + next_scores = tf.where(indices_to_remove, self.filter_value, scores) + return next_scores + + +class TFTopPLogitsWarper(TFLogitsWarper): + """ + [`TFLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to <= prob_cut_off. + + Args: + top_p (`float`): + If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept + for generation. + filter_value (`float`, *optional*, defaults to `-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. + """ + + def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0): + raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}") + + self.top_p = top_p + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor: + topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1]) + + mask_scores = tf.fill(scores.shape, self.filter_value) + cumulative_probs = tf.math.cumsum(tf.nn.softmax(topk_scores, axis=-1), axis=-1) + score_mask = cumulative_probs < self.top_p + + # Also include the token that is higher than top_p (the first false = shift and insert a True on the left) + score_mask = tf.concat((tf.ones([score_mask.shape[0], 1], dtype=tf.bool), score_mask[:, :-1]), axis=-1) + + # Ensure min tokens to keep + score_mask = tf.concat( + ( + tf.ones([score_mask.shape[0], self.min_tokens_to_keep], dtype=tf.bool), + score_mask[:, self.min_tokens_to_keep :], + ), + axis=-1, + ) + + # Mask the values that do not fit the criteria + topk_next_scores = tf.where(score_mask, topk_scores, mask_scores) + + # Undo the topk sorting: converts the 2D matrix of per-row original indices of shape (batch_size, vocab_size) + # to a 3D tensor of shape (batch_size, vocab_size, 2) containing the original score coordinate, from which we + # can scatter (i.e. `scatter_indices[row, col, :]` is a tensor containing `[row, topk_indices[row, col]]`) + scatter_rows = tf.tile(tf.expand_dims(tf.range(topk_indices.shape[0]), axis=-1), [1, topk_indices.shape[-1]]) + scatter_indices = tf.stack((scatter_rows, topk_indices), axis=-1) + next_scores = tf.scatter_nd(scatter_indices, topk_next_scores, shape=topk_next_scores.shape) + + return next_scores + + class TFMinLengthLogitsProcessor(TFLogitsProcessor): r""" [`TFLogitsProcessor`] enforcing a min-length by setting EOS probability to 0. 
diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index b8d4746fe2e..31f019e1b8e 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -28,6 +28,9 @@ from .generation_tf_logits_process import ( TFNoBadWordsLogitsProcessor, TFNoRepeatNGramLogitsProcessor, TFRepetitionPenaltyLogitsProcessor, + TFTemperatureLogitsWarper, + TFTopKLogitsWarper, + TFTopPLogitsWarper, ) from .tf_utils import set_tensor_by_indices_to_value, shape_list from .utils import logging @@ -558,9 +561,7 @@ class TFGenerationMixin: num_beams = num_beams if num_beams is not None else self.config.num_beams do_sample = do_sample if do_sample is not None else self.config.do_sample - is_greedy_gen_mode = num_beams == 1 and do_sample is False - - if is_greedy_gen_mode: + if num_beams == 1: return self._generate( input_ids=input_ids, max_length=max_length, @@ -790,304 +791,34 @@ class TFGenerationMixin: cur_len < max_length ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`" - if num_beams == 1: - return self._generate_no_beam_search( - input_ids, - cur_len=cur_len, - max_length=max_length, - min_length=min_length, - do_sample=do_sample, - temperature=temperature, - top_k=top_k, - top_p=top_p, - repetition_penalty=repetition_penalty, - no_repeat_ngram_size=no_repeat_ngram_size, - bad_words_ids=bad_words_ids, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - batch_size=effective_batch_size, - vocab_size=vocab_size, - encoder_outputs=encoder_outputs, - attention_mask=attention_mask, - use_cache=use_cache, - return_dict_in_generate=return_dict_in_generate, - **model_kwargs, - ) - else: - return self._generate_beam_search( - input_ids, - cur_len=cur_len, - max_length=max_length, - min_length=min_length, - do_sample=do_sample, - early_stopping=early_stopping, - temperature=temperature, - top_k=top_k, - top_p=top_p, - repetition_penalty=repetition_penalty, - no_repeat_ngram_size=no_repeat_ngram_size, - bad_words_ids=bad_words_ids, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - batch_size=effective_batch_size, - num_return_sequences=num_return_sequences, - length_penalty=length_penalty, - num_beams=num_beams, - vocab_size=vocab_size, - encoder_outputs=encoder_outputs, - attention_mask=attention_mask, - use_cache=use_cache, - forced_bos_token_id=forced_bos_token_id, - forced_eos_token_id=forced_eos_token_id, - return_dict_in_generate=return_dict_in_generate, - **model_kwargs, - ) - - def _generate_no_beam_search( - self, - input_ids, - cur_len, - max_length, - min_length, - do_sample, - temperature, - top_k, - top_p, - repetition_penalty, - no_repeat_ngram_size, - bad_words_ids, - pad_token_id, - eos_token_id, - batch_size, - vocab_size, - encoder_outputs, - attention_mask, - use_cache, - return_dict_in_generate, - **kwargs - ) -> Union[TFGreedySearchOutput, TFSampleOutput, tf.Tensor]: - """ - Generate sequences for each example without beam search (num_beams == 1). All returned sequences are generated - independently. 
- """ - - # length of generated sentences / unfinished sentences - unfinished_sents = tf.ones_like(input_ids[:, 0]) - sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length - - # defined for encoder-decoder models, None for decoder-only models - past = encoder_outputs - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and kwargs["output_scores"]) else None - decoder_attentions = () if (return_dict_in_generate and kwargs["output_attentions"]) else None - cross_attentions = () if (return_dict_in_generate and kwargs["output_attentions"]) else None - decoder_hidden_states = () if (return_dict_in_generate and kwargs["output_hidden_states"]) else None - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if self.config.is_encoder_decoder: - encoder_attentions = ( - kwargs["encoder_attentions"] if (return_dict_in_generate and kwargs["encoder_attentions"]) else None - ) - encoder_hidden_states = ( - kwargs["encoder_hidden_states"] - if (return_dict_in_generate and kwargs["encoder_hidden_states"]) - else None - ) - - while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation( - input_ids, - past=past, - attention_mask=attention_mask, - use_cache=use_cache, - **kwargs, - ) - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=kwargs["output_attentions"], - output_hidden_states=kwargs["output_hidden_states"], - ) - next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if kwargs["output_scores"]: - scores += (next_token_logits,) - if kwargs["output_attentions"]: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if kwargs["output_hidden_states"]: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # if model has past, then set the past variable to speed up decoding - if self._use_cache(outputs, use_cache): - past = outputs[1] - - # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) - if repetition_penalty != 1.0: - next_token_logits_penalties = _create_next_token_logits_penalties( - input_ids, next_token_logits, repetition_penalty - ) - next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties) - - if no_repeat_ngram_size > 0: - # calculate a list of banned tokens to prevent repetitively generating the same ngrams - # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 - banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len) - # create banned_tokens boolean mask - banned_tokens_indices_mask = [] - for banned_tokens_slice in banned_tokens: - banned_tokens_indices_mask.append( - [True if token in banned_tokens_slice else False for token in range(vocab_size)] - ) - - next_token_logits = set_tensor_by_indices_to_value( - next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf") - ) - - if bad_words_ids is not None: - # calculate a list of banned tokens according to bad words - banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids) - - banned_tokens_indices_mask = [] - for banned_tokens_slice in 
banned_tokens: - banned_tokens_indices_mask.append( - [True if token in banned_tokens_slice else False for token in range(vocab_size)] - ) - - next_token_logits = set_tensor_by_indices_to_value( - next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf") - ) - - # set eos token prob to zero if min_length is not reached - if eos_token_id is not None and cur_len < min_length: - # create eos_token_id boolean mask - is_token_logit_eos_token = tf.convert_to_tensor( - [True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool - ) - eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size]) - - next_token_logits = set_tensor_by_indices_to_value( - next_token_logits, eos_token_indices_mask, -float("inf") - ) - - if do_sample: - # Temperature (higher temperature => more likely to sample low probability tokens) - if temperature != 1.0: - next_token_logits = next_token_logits / temperature - # Top-p/top-k filtering - next_token_logits = tf_top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) - # Sample - next_token = tf.squeeze( - tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=1), axis=1 - ) - else: - # Greedy decoding - next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32) - - # update generations and finished sentences - if eos_token_id is not None: - # pad finished sentences if eos_token_id exist - tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents) - else: - tokens_to_add = next_token - - # add token and increase length by one - input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1) - cur_len = cur_len + 1 - - if eos_token_id is not None: - eos_in_sents = tokens_to_add == eos_token_id - # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length - is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply( - unfinished_sents, tf.cast(eos_in_sents, tf.int32) - ) - sent_lengths = ( - sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos) - + cur_len * is_sents_unfinished_and_token_to_add_is_eos - ) - - # unfinished_sents is set to zero if eos in sentence - unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos - - # stop when there is a in each sentence, or if we exceed the maximum length - if tf.math.reduce_max(unfinished_sents) == 0: - break - - # extend attention_mask for new generated input if only decoder - if self.config.is_encoder_decoder is False: - attention_mask = tf.concat( - [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1 - ) - - # if there are different sentences lengths in the batch, some batches have to be padded - min_sent_length = tf.math.reduce_min(sent_lengths) - max_sent_length = tf.math.reduce_max(sent_lengths) - if min_sent_length != max_sent_length: - assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths" - # finished sents are filled with pad_token - padding = tf.ones([batch_size, max_sent_length.numpy()], dtype=tf.int32) * pad_token_id - - # create length masks for tf.where operation - broad_casted_sent_lengths = tf.broadcast_to( - tf.expand_dims(sent_lengths, -1), [batch_size, max_sent_length] - ) - broad_casted_range = tf.transpose( - tf.broadcast_to(tf.expand_dims(tf.range(max_sent_length), -1), [max_sent_length, batch_size]) - ) - - decoded = tf.where(broad_casted_range < broad_casted_sent_lengths, input_ids, 
padding) - else: - decoded = input_ids - - if return_dict_in_generate: - if do_sample: - if self.config.is_encoder_decoder: - return TFSampleEncoderDecoderOutput( - sequences=decoded, - scores=scores, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - ) - else: - return TFSampleDecoderOnlyOutput( - sequences=decoded, - scores=scores, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) - else: - if self.config.is_encoder_decoder: - return TFGreedySearchEncoderDecoderOutput( - sequences=decoded, - scores=scores, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - ) - else: - return TFGreedySearchDecoderOnlyOutput( - sequences=decoded, - scores=scores, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) - else: - return decoded + return self._generate_beam_search( + input_ids, + cur_len=cur_len, + max_length=max_length, + min_length=min_length, + do_sample=do_sample, + early_stopping=early_stopping, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + batch_size=effective_batch_size, + num_return_sequences=num_return_sequences, + length_penalty=length_penalty, + num_beams=num_beams, + vocab_size=vocab_size, + encoder_outputs=encoder_outputs, + attention_mask=attention_mask, + use_cache=use_cache, + forced_bos_token_id=forced_bos_token_id, + forced_eos_token_id=forced_eos_token_id, + return_dict_in_generate=return_dict_in_generate, + **model_kwargs, + ) def _generate_beam_search( self, @@ -1761,11 +1492,6 @@ class TFGenerationMixin: input_ids, return_dict_in_generate, model_kwargs ) - # TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger - # refactor of all generation models in TF. `past` should be - # optional everywhere and not be set equal to encoder_outputs - model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None - # 4. Prepare `input_ids` which will be used for auto-regressive generation if self.config.is_encoder_decoder: # if encoder-decoder then `input_ids` come from `decoder_start_token_id` @@ -1787,6 +1513,7 @@ class TFGenerationMixin: # 5. determine generation mode # TODO(Matt, Joao, Patrick) - add more use cases here is_greedy_gen_mode = (num_beams == 1) and do_sample is False + is_sample_gen_mode = (num_beams == 1) and do_sample is True # 6. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( @@ -1804,6 +1531,10 @@ class TFGenerationMixin: f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." ) + # TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all + # generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs. + model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None + # 8. run greedy search return self.greedy_search( input_ids, @@ -1816,6 +1547,35 @@ class TFGenerationMixin: **model_kwargs, ) + elif is_sample_gen_mode: + # 8. 
prepare logits warper + logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature) + + # 9. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids, + expand_size=num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all + # generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs. + model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None + + # 10. run sample + return self.sample( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + max_length=max_length, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + **model_kwargs, + ) + # TODO(Matt, Joao, Patrick) - add more sub-generation methods here def _prepare_attention_mask_for_generation( @@ -1908,6 +1668,36 @@ class TFGenerationMixin: "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." ) + @staticmethod + def _expand_inputs_for_generation( + input_ids: tf.Tensor, + expand_size: int = 1, + is_encoder_decoder: bool = False, + attention_mask: Optional[tf.Tensor] = None, + encoder_outputs: Optional[ModelOutput] = None, + **model_kwargs, + ) -> Tuple[tf.Tensor, Dict[str, Any]]: + expanded_return_idx = tf.reshape( + tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1) + ) + input_ids = tf.gather(input_ids, expanded_return_idx, axis=0) + + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = tf.gather(token_type_ids, expanded_return_idx, axis=0) + + if attention_mask is not None: + model_kwargs["attention_mask"] = tf.gather(attention_mask, expanded_return_idx, axis=0) + + if is_encoder_decoder: + if encoder_outputs is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + encoder_outputs["last_hidden_state"] = tf.gather( + encoder_outputs.last_hidden_state, expanded_return_idx, axis=0 + ) + model_kwargs["encoder_outputs"] = encoder_outputs + return input_ids, model_kwargs + def _prepare_model_inputs(self, inputs: Optional[tf.Tensor] = None, bos_token_id: Optional[int] = None): # TODO(Patrick) - adapt this function when making `generate` more flexible # for all kinds of input types @@ -1956,6 +1746,34 @@ class TFGenerationMixin: return model_kwargs + def _get_logits_warper( + self, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ) -> TFLogitsProcessorList: + """ + This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsWarper`] + instances used for multinomial sampling. 
+ """ + + # init warp parameters + top_k = top_k if top_k is not None else self.config.top_k + top_p = top_p if top_p is not None else self.config.top_p + temperature = temperature if temperature is not None else self.config.temperature + # instantiate warpers list + warpers = TFLogitsProcessorList() + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if temperature is not None and temperature != 1.0: + warpers.append(TFTemperatureLogitsWarper(temperature)) + if top_k is not None and top_k != 0: + warpers.append(TFTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1)) + if top_p is not None and top_p < 1.0: + warpers.append(TFTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1)) + return warpers + def _get_logits_processor( self, repetition_penalty: float, @@ -2029,8 +1847,8 @@ class TFGenerationMixin: return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. model_kwargs: - Additional model specific keyword arguments will be forwarded to the `forward` function of the model. - If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + Additional model specific keyword arguments will be forwarded to the `call` function of the model. If + model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: [`~generation_tf_utils.TFGreedySearchDecoderOnlyOutput`], @@ -2043,13 +1861,13 @@ class TFGenerationMixin: ```python >>> from transformers import ( - ... TFAutoTokenizer, + ... AutoTokenizer, ... TFAutoModelForCausalLM, ... TFLogitsProcessorList, ... TFMinLengthLogitsProcessor, ... ) - >>> tokenizer = TFAutoTokenizer.from_pretrained("gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") >>> model = TFAutoModelForCausalLM.from_pretrained("gpt2") >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token @@ -2195,6 +2013,230 @@ class TFGenerationMixin: else: return input_ids + def sample( + self, + input_ids: tf.Tensor, + logits_processor: Optional[TFLogitsProcessorList] = None, + logits_warper: Optional[TFLogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + **model_kwargs, + ) -> Union[TFSampleOutput, tf.Tensor]: + r""" + Generates sequences for models with a language modeling head using multinomial sampling. + + Parameters: + + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`TFLogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + logits_warper (`TFLogitsProcessorList`, *optional*): + An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsWarper`] + used to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + max_length (`int`, *optional*, defaults to 20): + The maximum length of the sequence to be generated. 
+ pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. + model_kwargs: + Additional model specific kwargs will be forwarded to the `call` function of the model. If model is an + encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation_tf_utils.TFSampleDecoderOnlyOutput`], [`~generation_tf_utils.TFSampleEncoderDecoderOutput`] + or `tf.Tensor`: A `tf.Tensor` containing the generated tokens (default behaviour) or a + [`~generation_tf_utils.TFSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation_tf_utils.TFSampleEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... TFAutoModelForCausalLM, + ... TFLogitsProcessorList, + ... TFMinLengthLogitsProcessor, + ... TFTopKLogitsWarper, + ... TFTemperatureLogitsWarper, + ... ) + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = TFAutoModelForCausalLM.from_pretrained("gpt2") + + >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token + >>> model.config.pad_token_id = model.config.eos_token_id + + >>> input_prompt = "Today is a beautiful day, and" + >>> input_ids = tokenizer(input_prompt, return_tensors="tf").input_ids + + >>> # instantiate logits processors + >>> logits_processor = TFLogitsProcessorList( + ... [ + ... TFMinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id), + ... ] + ... ) + >>> # instantiate logits processors + >>> logits_warper = TFLogitsProcessorList( + ... [ + ... TFTopKLogitsWarper(50), + ... TFTemperatureLogitsWarper(0.7), + ... ] + ... 
) + + >>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper) + + >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) + ```""" + + # init values + logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList() + logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList() + + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # TODO(Patrick): `encoder_outputs`, `past` hack. Currently T5, Bart expect `encoder_outputs` + # to be wrapped into `past` variable. This is a bad design and needs to be updated. + # Remove the following lines when updating all encoder-decoder models + encoder_outputs = model_kwargs.pop("encoder_outputs", None) + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = encoder_outputs.get("attentions") if output_attentions else None + encoder_hidden_states = encoder_outputs.get("hidden_states") if output_hidden_states else None + + # keep track of which sequences are already finished + unfinished_sequences = tf.ones_like(input_ids[:, 0]) + cur_len = input_ids.shape[-1] + + while cur_len < max_length: + # TODO(Patrick): remove following line by cleaning up `prepare_inputs_for_generation` + # in all models + model_kwargs["use_cache"] = None if "use_cache" not in model_kwargs else model_kwargs["use_cache"] + + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # sample + 
next_tokens = tf.squeeze( + tf.random.categorical(logits=next_token_scores, num_samples=1, dtype=tf.int32), axis=1 + ) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = tf.concat([input_ids, next_tokens[:, None]], axis=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + cur_len = cur_len + 1 + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + eos_in_sents = next_tokens == eos_token_id + # if sentence is unfinished and the token to add is eos + is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply( + unfinished_sequences, tf.cast(eos_in_sents, tf.int32) + ) + + # unfinished_sequences is set to zero if eos in sentence + unfinished_sequences -= is_sents_unfinished_and_token_to_add_is_eos + + # stop when each sentence is finished, or if we exceed the maximum length + if tf.math.reduce_max(unfinished_sequences) == 0: + break + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return TFSampleEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return TFSampleDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return input_ids + def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty): # create logit penalties for already seen input_ids @@ -2292,7 +2334,6 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In # Remove all tokens with a probability less than the last token of the top-k indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None] logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value) - if top_p < 1.0: sorted_indices = tf.argsort(logits, direction="DESCENDING") sorted_logits = tf.gather( diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 6700663afff..379d87c484f 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -556,8 +556,8 @@ class GenerationMixin: input_ids: torch.LongTensor, expand_size: int = 1, is_encoder_decoder: bool = False, - attention_mask: torch.LongTensor = None, - encoder_outputs: ModelOutput = None, + attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[ModelOutput] = None, **model_kwargs, ) -> Tuple[torch.LongTensor, Dict[str, Any]]: expanded_return_idx = ( @@ -617,11 +617,11 @@ class GenerationMixin: def _get_logits_warper( self, - top_k: int = None, - top_p: float = None, - typical_p: float = None, - temperature: float = None, - num_beams: int = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + temperature: Optional[float] = None, + num_beams: Optional[int] = None, ) -> LogitsProcessorList: """ This class returns a [`LogitsProcessorList`] list object that contains 
all relevant [`LogitsWarper`] instances diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index ae7ffee3fb9..29495321013 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -31,6 +31,13 @@ class TFLogitsProcessorList(metaclass=DummyObject): requires_backends(self, ["tf"]) +class TFLogitsWarper(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFMinLengthLogitsProcessor(metaclass=DummyObject): _backends = ["tf"] @@ -59,6 +66,27 @@ class TFRepetitionPenaltyLogitsProcessor(metaclass=DummyObject): requires_backends(self, ["tf"]) +class TFTemperatureLogitsWarper(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFTopKLogitsWarper(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFTopPLogitsWarper(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + def tf_top_k_top_p_filtering(*args, **kwargs): requires_backends(tf_top_k_top_p_filtering, ["tf"]) diff --git a/tests/generation/test_generation_logits_process.py b/tests/generation/test_generation_logits_process.py index d8873bc02f1..5ffc6843a1f 100644 --- a/tests/generation/test_generation_logits_process.py +++ b/tests/generation/test_generation_logits_process.py @@ -51,7 +51,7 @@ class LogitsProcessorTest(unittest.TestCase): scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length return scores - def test_min_lenght_dist_processor(self): + def test_min_length_dist_processor(self): vocab_size = 20 batch_size = 4 eos_token_id = 0 diff --git a/tests/generation/test_generation_tf_logits_process.py b/tests/generation/test_generation_tf_logits_process.py index 3869ddbc40b..2dc366f05c7 100644 --- a/tests/generation/test_generation_tf_logits_process.py +++ b/tests/generation/test_generation_tf_logits_process.py @@ -16,6 +16,8 @@ import unittest +import numpy as np + from transformers import is_tf_available from transformers.testing_utils import require_tf @@ -29,6 +31,9 @@ if is_tf_available(): TFNoBadWordsLogitsProcessor, TFNoRepeatNGramLogitsProcessor, TFRepetitionPenaltyLogitsProcessor, + TFTemperatureLogitsWarper, + TFTopKLogitsWarper, + TFTopPLogitsWarper, ) from transformers.tf_utils import set_tensor_by_indices_to_value @@ -38,7 +43,7 @@ if is_tf_available(): @require_tf class TFLogitsProcessorTest(unittest.TestCase): def _get_uniform_logits(self, batch_size: int, length: int): - scores = tf.ones((batch_size, length), dtype=tf.float32) / length + scores = np.ones((batch_size, length), dtype=np.float32) / length return scores def test_min_length_dist_processor(self): @@ -60,6 +65,37 @@ class TFLogitsProcessorTest(unittest.TestCase): scores_before_min_length = min_dist_processor(input_ids, scores) self.assertFalse(tf.math.reduce_any(tf.math.is_inf(scores_before_min_length)).numpy()) + def test_temperature_dist_warper(self): + input_ids = None + length = 20 + + scores = self._get_uniform_logits(batch_size=2, length=length) + + # tweak scores to not be uniform anymore + scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch + scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch + + # compute softmax + probs = tf.nn.softmax(scores, axis=-1) + + temp_dist_warper_sharper = 
TFTemperatureLogitsWarper(temperature=0.5) + temp_dist_warper_smoother = TFTemperatureLogitsWarper(temperature=1.3) + + warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores)), axis=-1) + warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores)), axis=-1) + + # uniform distribution stays uniform + tf.debugging.assert_near(probs[0, :], warped_prob_sharp[0, :], atol=1e-3) + tf.debugging.assert_near(probs[0, :], warped_prob_smooth[0, :], atol=1e-3) + + # sharp peaks get higher, valleys get lower + self.assertLess(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_sharp[1, :])) + self.assertGreater(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_sharp[1, :])) + + # smooth peaks get lower, valleys get higher + self.assertGreater(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_smooth[1, :])) + self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :])) + def test_repetition_penalty_dist_process(self): input_ids = tf.constant([[0, 1], [5, 0]], dtype=tf.int32) vocab_size = 10 @@ -82,6 +118,73 @@ class TFLogitsProcessorTest(unittest.TestCase): self.assertAlmostEqual(scores[1, 0].numpy(), (1 / vocab_size) / 2) self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2) + def test_top_k_dist_warper(self): + input_ids = None + vocab_size = 10 + batch_size = 2 + + # create ramp distribution + ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy() + ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size + + top_k_warp = TFTopKLogitsWarper(3) + + scores = top_k_warp(input_ids, ramp_logits) + + # check that correct tokens are filtered + self.assertListEqual(tf.math.is_inf(scores[0]).numpy().tolist(), 7 * [True] + 3 * [False]) + self.assertListEqual(tf.math.is_inf(scores[1]).numpy().tolist(), 2 * [True] + 3 * [False] + 5 * [True]) + + # check special cases + length = 5 + + logits = self._get_uniform_logits(batch_size=batch_size, length=length) + top_k_warp_safety_check = TFTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3) + + scores = top_k_warp_safety_check(input_ids, logits) + # uniform dist is not changed + self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [0, 0]) + + ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy() + scores = top_k_warp_safety_check(input_ids, ramp_logits) + + # min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified + self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [2, 2]) + + def test_top_p_dist_warper(self): + input_ids = None + vocab_size = 10 + batch_size = 2 + + # create distribution and take log (inverse to Softmax as taken in TFTopPLogitsWarper) + dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], dtype=np.float32)) + + top_p_warp = TFTopPLogitsWarper(0.7) + filtered_dist = tf.exp(top_p_warp(input_ids, dist)) + + # dist should be filtered to keep min num values so that sum is >= 0.7 + # exp (-inf) => 0 + EXPECTED_FILTERED_DIST = tf.constant([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], dtype=tf.float32) + tf.debugging.assert_near(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3) + + # check edge cases with negative and extreme logits + ramp_logits = np.broadcast_to( + np.arange(vocab_size, dtype=np.float32)[None, :], (batch_size, vocab_size) + ).copy() - (vocab_size // 2) 
+ + # make ramp_logits more extreme + ramp_logits[1] = ramp_logits[1] * 100.0 + + # make sure at least 2 tokens are kept + top_p_warp = TFTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0) + filtered_dist = top_p_warp(input_ids, ramp_logits) + + # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps + # 2. + self.assertListEqual( + tf.math.reduce_sum(tf.where(filtered_dist != 0.0, 1, 0), axis=-1).numpy().tolist(), [3, 2] + ) + def test_no_repeat_ngram_dist_processor(self): vocab_size = 3 batch_size = 2 @@ -140,13 +243,19 @@ class TFLogitsProcessorTest(unittest.TestCase): # instantiate all dist processors min_dist_proc = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) + temp_dist_warp = TFTemperatureLogitsWarper(temperature=0.5) rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0) + top_k_warp = TFTopKLogitsWarper(3) + top_p_warp = TFTopPLogitsWarper(0.8) no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2) no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id) # no processor list scores = min_dist_proc(input_ids, scores) + scores = temp_dist_warp(input_ids, scores) scores = rep_penalty_proc(input_ids, scores) + scores = top_k_warp(input_ids, scores) + scores = top_p_warp(input_ids, scores) scores = no_repeat_proc(input_ids, scores) scores = no_bad_words_dist_proc(input_ids, scores) @@ -154,7 +263,10 @@ class TFLogitsProcessorTest(unittest.TestCase): processor = TFLogitsProcessorList( [ min_dist_proc, + temp_dist_warp, rep_penalty_proc, + top_k_warp, + top_p_warp, no_repeat_proc, no_bad_words_dist_proc, ] diff --git a/tests/gpt2/test_modeling_tf_gpt2.py b/tests/gpt2/test_modeling_tf_gpt2.py index 0952ff8f704..4bc1c876e02 100644 --- a/tests/gpt2/test_modeling_tf_gpt2.py +++ b/tests/gpt2/test_modeling_tf_gpt2.py @@ -488,9 +488,12 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase): "top_k": 500, "top_p": 0.9, } - tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation - output_ids = model.generate(input_ids, **generation_kwargs) + # forces the generation to happen on CPU, to avoid GPU-related quirks + with tf.device(":/CPU:0"): + tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation + output_ids = model.generate(input_ids, **generation_kwargs) + output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) expected_output_string = [ diff --git a/tests/t5/test_modeling_tf_t5.py b/tests/t5/test_modeling_tf_t5.py index 77d0f674680..49d020d17bc 100644 --- a/tests/t5/test_modeling_tf_t5.py +++ b/tests/t5/test_modeling_tf_t5.py @@ -497,9 +497,11 @@ class TFT5GenerationIntegrationTests(unittest.TestCase): "top_k": 500, "top_p": 0.9, } - tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation - output_ids = model.generate(input_ids, **generation_kwargs) + # forces the generation to happen on CPU, to avoid GPU-related quirks + with tf.device(":/CPU:0"): + tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation + output_ids = model.generate(input_ids, **generation_kwargs) output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 142bff7cae0..a8acf2d5631 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -947,7 +947,7 @@ class TFModelTesterMixin: if config.bos_token_id is 
None: # if bos token id is not defined model needs input_ids - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): model.generate(do_sample=True, max_length=5) # num_return_sequences = 1 self._check_generated_ids(model.generate(input_ids, do_sample=True))
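
End to end, the new sampling path is reached through generate whenever num_beams == 1 and do_sample=True, with top_k, top_p, and temperature forwarded to the warpers above via _get_logits_warper. A rough usage sketch (illustrative only, not part of the patch; checkpoint name, prompt, seed, and generation parameters are arbitrary):

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
model.config.pad_token_id = model.config.eos_token_id  # GPT-2 has no dedicated pad token

input_ids = tokenizer("Today is a beautiful day, and", return_tensors="tf").input_ids

tf.random.set_seed(0)  # sampling is stochastic; seed it for reproducibility
output_ids = model.generate(
    input_ids,
    do_sample=True,   # routes generation through the new `sample` method
    max_length=32,
    temperature=0.7,  # applied by TFTemperatureLogitsWarper
    top_k=50,         # applied by TFTopKLogitsWarper
    top_p=0.9,        # applied by TFTopPLogitsWarper
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))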