diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py index 64c47a1ea62..5d619d1d19c 100644 --- a/src/transformers/generation/tf_utils.py +++ b/src/transformers/generation/tf_utils.py @@ -14,10 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import inspect import warnings from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import numpy as np import tensorflow as tf @@ -32,6 +33,7 @@ from ..models.auto import ( ) from ..tf_utils import shape_list, stable_softmax from ..utils import ModelOutput, logging +from .configuration_utils import GenerationConfig from .tf_logits_process import ( TFForcedBOSTokenLogitsProcessor, TFForcedEOSTokenLogitsProcessor, @@ -449,6 +451,11 @@ class TFGenerationMixin: supports_xla_generation = True + def prepare_inputs_for_generation(self, *args, **kwargs): + raise NotImplementedError( + "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`." + ) + def adjust_logits_during_generation( self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs ): @@ -475,7 +482,7 @@ class TFGenerationMixin: Confirms that the model class is compatible with generation. If not, raises an exception that points to the right class to use. """ - if not hasattr(self, "prepare_inputs_for_generation"): + if not self.can_generate(): generate_compatible_mappings = [ TF_MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_VISION_2_SEQ_MAPPING, @@ -520,153 +527,43 @@ class TFGenerationMixin: def generate( self, - input_ids=None, - max_length=None, - max_new_tokens=None, - min_length=None, - do_sample=None, - early_stopping=None, - num_beams=None, - temperature=None, - penalty_alpha=None, - top_k=None, - top_p=None, - repetition_penalty=None, - bad_words_ids=None, - bos_token_id=None, - pad_token_id=None, - eos_token_id=None, - length_penalty=None, - no_repeat_ngram_size=None, - num_return_sequences=None, - attention_mask=None, - decoder_start_token_id=None, - use_cache=None, + input_ids: Optional[tf.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, seed=None, - output_scores=None, - output_attentions=None, - output_hidden_states=None, - return_dict_in_generate=None, - forced_bos_token_id=None, - forced_eos_token_id=None, - suppress_tokens=None, - begin_suppress_tokens=None, - forced_decoder_ids=None, - **model_kwargs, + **kwargs, ) -> Union[TFGenerateOutput, tf.Tensor]: r""" - Generates sequences of token ids for models with a language modeling head. The method supports the following - generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models: - - *greedy decoding* by calling [`~generation.TFGenerationMixin.greedy_search`] if `num_beams=1` and - `do_sample=False`. - - *contrastive search* by calling [`~generation.TFGenerationMixin.contrastive_search`] if `penalty_alpha>0` - and `top_k>1` - - *multinomial sampling* by calling [`~generation.TFGenerationMixin.sample`] if `num_beams=1` and - `do_sample=True`. - - *beam-search decoding* by calling [`~generation.TFGenerationMixin.beam_search`] if `num_beams>1` and - `do_sample=False`. + Generates sequences of token ids for models with a language modeling head. 
- Adapted in part from [Facebook's XLM beam search - code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529). + - Apart from `input_ids` and `attention_mask`, all the arguments below will default to the value of the attribute - of the same name inside the [`PretrainedConfig`] of the model. The default values indicated are the default - values of those config. + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`. - Most of these parameters are explained in more detail in [this blog - post](https://huggingface.co/blog/how-to-generate). + For a complete overview of generate, check the [following + guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation). + + Parameters: input_ids (`tf.Tensor` of `dtype=tf.int32` and shape `(batch_size, sequence_length)`, *optional*): The sequence used as a prompt for the generation. If `None` the method initializes it with `bos_token_id` and a batch size of 1. - max_length (`int`, *optional*, defaults to `model.config.max_length`): - The maximum length the generated tokens can have. Corresponds to the length of the input prompt + - `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in - the prompt. - max_new_tokens (`int`, *optional*): - The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. - min_length (`int`, *optional*, defaults to 10): - The minimum length of the sequence to be generated. - do_sample (`bool`, *optional*, defaults to `False`): - Whether or not to use sampling ; use greedy decoding otherwise. - early_stopping (`bool`, *optional*, defaults to `False`): - Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not. - num_beams (`int`, *optional*, defaults to 1): - Number of beams for beam search. 1 means no beam search. - temperature (`float`, *optional*, defaults to 1.0): - The value used to module the next token probabilities. - penalty_alpha (`float`, *optional*): - The values balance the model confidence and the degeneration penalty in contrastive search decoding. - top_k (`int`, *optional*, defaults to 50): - The number of highest probability vocabulary tokens to keep for top-k-filtering. - top_p (`float`, *optional*, defaults to 1.0): - If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher - are kept for generation. - repetition_penalty (`float`, *optional*, defaults to 1.0): - The parameter for repetition penalty. 1.0 means no penalty. See [this - paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - bos_token_id (`int`, *optional*): - The id of the *beginning-of-sequence* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. - length_penalty (`float`, *optional*, defaults to 1.0): - Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent - to the sequence length, which in turn is used to divide the score of the sequence. Since the score is - the log likelihood of the sequence (i.e. 
negative), `length_penalty` > 0.0 promotes longer sequences, - while `length_penalty` < 0.0 encourages shorter sequences. - no_repeat_ngram_size (`int`, *optional*, defaults to 0): - If set to int > 0, all ngrams of that size can only occur once. - bad_words_ids(`List[int]`, *optional*): - List of token ids that are not allowed to be generated. In order to get the tokens of the words that - should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`. - num_return_sequences(`int`, *optional*, defaults to 1): - The number of independently computed returned sequences for each element in the batch. - attention_mask (`tf.Tensor` of `dtype=tf.int32` and shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens - that are not masked, and 0 for masked tokens. - - If not provided, will default to a tensor the same shape as `input_ids` that masks the pad token. - - [What are attention masks?](../glossary#attention-mask) - decoder_start_token_id (`int`, *optional*): - If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should use the past last key/values attentions (if applicable to the model) to - speed up decoding. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which has the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. seed (`List[int]`, *optional*): Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the `seed` argument from stateless functions in `tf.random`. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - forced_bos_token_id (`int`, *optional*): - The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful - for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be - the target language token. - forced_eos_token_id (`int`, *optional*): - The id of the token to force as the last generated token when `max_length` is reached. - suppress_tokens (`List[int]`, *optional*, defaults to `model.config.suppress_tokens`): - A list of tokens that will be supressed at generation. 
The `SupressTokens` logit processor will set - their log probs to `-inf` so that they are not sampled. - begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`): - A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens` - logit processor will set their log probs to `-inf` so that they are not sampled. - forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`): - A list of pairs of integers which indicates a mapping from generation indices to token indices that - will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always - be a token of index 123. - model_kwargs: - Additional model specific kwargs will be forwarded to the `call` function of the model. + kwargs: + Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be + forwarded to the `call` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. Return: [`~utils.ModelOutput`] or `tf.Tensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` or when @@ -690,59 +587,92 @@ class TFGenerationMixin: Examples: + Greedy decoding, using the default generation configuration and ad hoc modifications: + ```python - tokenizer = AutoTokenizer.from_pretrained("distilgpt2") # Initialize tokenizer - model = TFAutoModelWithLMHead.from_pretrained("distilgpt2") - # Greedy decoding - outputs = model.generate(max_length=40) - print(f"Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)}") + >>> from transformers import AutoTokenizer, TFAutoModelForCausalLM - tokenizer = AutoTokenizer.from_pretrained("openai-gpt") - model = TFAutoModelWithLMHead.from_pretrained("openai-gpt") - input_context = "The dog" - input_ids = tokenizer.encode(input_context, return_tensors="tf") # encode input context - # Generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog' - outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) - # 3 output sequences were generated - for i in range(3): - print(f"Generated {i}: {tokenizer.decode(outputs[i], skip_special_tokens=True)}") + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = TFAutoModelForCausalLM.from_pretrained("gpt2") - tokenizer = AutoTokenizer.from_pretrained("distilgpt2") - model = TFAutoModelWithLMHead.from_pretrained("distilgpt2") - input_context = "The dog" - input_ids = tokenizer.encode(input_context, return_tensors="tf") - # Generate 3 candidates using sampling - outputs = model.generate( - input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True - ) - # 3 output sequences were generated - for i in range(3): - print(f"Generated {i}: {tokenizer.decode(outputs[i], skip_special_tokens=True)}") + >>> prompt = "Today I believe we can finally" + >>> input_ids = tokenizer(prompt, return_tensors="tf").input_ids - tokenizer = AutoTokenizer.from_pretrained("ctrl") - model = TFAutoModelWithLMHead.from_pretrained("ctrl") - # "Legal" is one of the control codes for ctrl - input_context = "Legal My neighbor is" - input_ids = tokenizer.encode(input_context, return_tensors="tf") - outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) - print(f"Generated: {tokenizer.decode(outputs[0], 
skip_special_tokens=True)}") + >>> # Generate up to 30 tokens + >>> outputs = model.generate(input_ids, do_sample=False, max_length=30) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n'] + ``` - tokenizer = AutoTokenizer.from_pretrained("gpt2") - model = TFAutoModelWithLMHead.from_pretrained("gpt2") - input_context = "My cute dog" - bad_words_ids = [ - tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ["idiot", "stupid", "shut up"] - ] - input_ids = tokenizer.encode(input_context, return_tensors="tf") - # generate sequences without allowing bad_words to be generated - outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) + Multinomial sampling, modifying an existing generation configuration: + + ```python + >>> from transformers import AutoTokenizer, TFAutoModelForCausalLM, GenerationConfig + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = TFAutoModelForCausalLM.from_pretrained("gpt2") + + >>> prompt = "Today I believe we can finally" + >>> input_ids = tokenizer(prompt, return_tensors="tf").input_ids + + >>> # Sample up to 30 tokens + >>> generation_config = GenerationConfig.from_pretrained("gpt2") + >>> generation_config.max_length = 30 + >>> generation_config.do_sample = True + >>> outputs = model.generate(input_ids, generation_config=generation_config, seed=[0, 0]) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ["Today I believe we can finally start taking a bold stand against climate change and climate change mitigation efforts such as President Obama's climate ban and President Trump's"] + ``` + + Beam-search decoding, using a freshly initialized generation configuration: + + ```python + >>> from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM, GenerationConfig + + >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") + >>> model = TFAutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de") + + >>> sentence = "Paris is one of the densest populated areas in Europe." + >>> input_ids = tokenizer(sentence, return_tensors="tf").input_ids + + >>> generation_config = GenerationConfig( + ... max_length=64, + ... num_beams=5, + ... bos_token_id=0, + ... eos_token_id=0, + ... decoder_start_token_id=58100, + ... pad_token_id=58100, + ... bad_words_ids=[[58100]], + ... ) + >>> outputs = model.generate(input_ids, generation_config=generation_config) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] ```""" - # 0. Validate the `.generate()` call + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() + + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + if generation_config is None: + # legacy: users may modify the model configuration to control generation -- update the generation config + # model attribute accordingly, if it was created from the model config + if self.generation_config._from_model_config: + new_generation_config = GenerationConfig.from_model_config(self.config) + if new_generation_config != self.generation_config: + warnings.warn( + "You have modified the pretrained model configuration to control generation. 
This is a" + " deprecated strategy to control generation and will be removed soon, in a future version." + " Please use a generation configuration file (see" + " https://huggingface.co/docs/transformers/main_classes/text_generation)" + ) + self.generation_config = new_generation_config + generation_config = self.generation_config + + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs self._validate_model_kwargs(model_kwargs.copy()) - # 1. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models) + # 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models) if input_ids is not None: if isinstance(input_ids, tf.Tensor) and input_ids.dtype.is_floating: pass @@ -750,8 +680,8 @@ class TFGenerationMixin: pass else: input_ids = tf.cast(input_ids, tf.int32) - if attention_mask is not None: - attention_mask = tf.cast(attention_mask, tf.int32) + if model_kwargs.get("attention_mask") is not None: + model_kwargs["attention_mask"] = tf.cast(model_kwargs["attention_mask"], tf.int32) if "decoder_input_ids" in model_kwargs: if ( isinstance(model_kwargs["decoder_input_ids"], tf.Tensor) @@ -765,44 +695,18 @@ class TFGenerationMixin: else: model_kwargs["decoder_input_ids"] = tf.cast(model_kwargs["decoder_input_ids"], tf.int32) - # 2. Set generation parameters if not already defined - length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty - early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping - - bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - - forced_bos_token_id = ( - forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id - ) - forced_eos_token_id = ( - forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id - ) - - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate - ) - - num_beams = num_beams if num_beams is not None else self.config.num_beams - do_sample = do_sample if do_sample is not None else self.config.do_sample - num_return_sequences = ( - num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences - ) - - if pad_token_id is None and eos_token_id is not None: - if attention_mask is None: + # 3. Set generation parameters if not already defined + if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: + if model_kwargs.get("attention_mask") is None: logger.warning( "The attention mask and the pad token id were not set. As a consequence, you may observe " "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." 
) - logger.warning(f"Setting `pad_token_id` to {eos_token_id} (first `eos_token_id`) to generate sequence") - pad_token_id = eos_token_id + logger.warning( + f"Setting `pad_token_id` to {generation_config.eos_token_id} (first `eos_token_id`) to generate" + " sequence" + ) + generation_config.pad_token_id = generation_config.eos_token_id use_xla = not tf.executing_eagerly() if use_xla and not self.supports_xla_generation: @@ -810,241 +714,242 @@ class TFGenerationMixin: "The selected model does not support Graph mode nor XLA generation (e.g. from tf.function())" ) - # 3. Define model inputs - input_ids = self._prepare_model_inputs(input_ids, bos_token_id) + # 4. Define model inputs + input_ids = self._prepare_model_inputs(input_ids, generation_config.bos_token_id) # inputs_ids now has to be defined and cannot be None anymore batch_size = shape_list(input_ids)[0] - # 4. Prepare other model kwargs - if output_attentions is not None: - model_kwargs["output_attentions"] = output_attentions - if output_hidden_states is not None: - model_kwargs["output_hidden_states"] = output_hidden_states - if use_cache is not None: - model_kwargs["use_cache"] = use_cache - if attention_mask is not None: - model_kwargs["attention_mask"] = attention_mask + # 5. Prepare other model kwargs + model_kwargs["output_attentions"] = generation_config.output_attentions + model_kwargs["output_hidden_states"] = generation_config.output_hidden_states + model_kwargs["use_cache"] = generation_config.use_cache accepts_attention_mask = "attention_mask" in set(inspect.signature(self.call).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - input_ids, pad_token_id, eos_token_id + input_ids, generation_config.pad_token_id, generation_config.eos_token_id ) # decoder-only models should use left-padding for generation if not self.config.is_encoder_decoder: - if pad_token_id is not None and tf.math.reduce_any(input_ids[:, -1] == pad_token_id): + if generation_config.pad_token_id is not None and tf.math.reduce_any( + input_ids[:, -1] == generation_config.pad_token_id + ): logger.warning( "A decoder-only architecture is being used, but right-padding was detected! For correct " "generation results, please set `padding_side='left'` when initializing the tokenizer." ) - # 5. Prepare model inputs which will be used for auto-regressive generation + # 6. Prepare model inputs which will be used for auto-regressive generation if self.config.is_encoder_decoder: # if encoder-decoder, we create encoder_outputs and add to `model_kwargs` model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs) # if encoder-decoder then `input_ids` come from `decoder_start_token_id` input_ids = self._prepare_decoder_input_ids_for_generation( batch_size, - decoder_start_token_id=decoder_start_token_id, - bos_token_id=bos_token_id, + decoder_start_token_id=generation_config.decoder_start_token_id, + bos_token_id=generation_config.bos_token_id, model_kwargs=model_kwargs, ) - # 6. Prepare `max_length` depending on other stopping criteria. + # 7. Prepare `max_length` depending on other stopping criteria. 
input_ids_seq_length = input_ids.shape[-1] - if max_length is None and max_new_tokens is None: + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( - "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to " - f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is " - "deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend " - "using `max_new_tokens` to control the maximum length of the generation.", + "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to" + f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the" + " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", UserWarning, ) - elif max_length is None and max_new_tokens is not None: - max_length = max_new_tokens + input_ids_seq_length - elif max_length is not None and max_new_tokens is not None: + elif has_default_max_length and generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + elif not has_default_max_length and generation_config.max_new_tokens is not None: raise ValueError( "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" " limit to the generated output length. Remove one of those arguments. Please refer to the" " documentation for more information. " "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" ) - # default to config if still None - max_length = max_length if max_length is not None else self.config.max_length - min_length = min_length if min_length is not None else self.config.min_length - if min_length is not None and min_length > max_length: + if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: raise ValueError( - f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum " - f"length ({max_length})" + f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" + f" the maximum length ({generation_config.max_length})" ) - if input_ids_seq_length >= max_length: + if input_ids_seq_length >= generation_config.max_length: input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" logger.warning( f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {max_length}. This can lead to unexpected behavior. You should consider increasing" - "`max_new_tokens`." + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." ) - # 7. determine generation mode + # 8. 
determine generation mode is_contrastive_search_gen_mode = ( - top_k is not None and top_k > 1 and do_sample is False and penalty_alpha is not None and penalty_alpha > 0 + generation_config.top_k is not None + and generation_config.top_k > 1 + and generation_config.do_sample is False + and generation_config.penalty_alpha is not None + and generation_config.penalty_alpha > 0 ) - is_greedy_gen_mode = not is_contrastive_search_gen_mode and (num_beams == 1) and do_sample is False - is_beam_gen_mode = not is_contrastive_search_gen_mode and (num_beams > 1) and do_sample is False - is_sample_gen_mode = (num_beams == 1) and do_sample is True - is_beam_sample_gen_mode = (num_beams > 1) and do_sample is True + is_greedy_gen_mode = ( + not is_contrastive_search_gen_mode + and (generation_config.num_beams == 1) + and generation_config.do_sample is False + ) + is_beam_gen_mode = ( + not is_contrastive_search_gen_mode + and (generation_config.num_beams > 1) + and generation_config.do_sample is False + ) + is_sample_gen_mode = (generation_config.num_beams == 1) and generation_config.do_sample is True + is_beam_sample_gen_mode = (generation_config.num_beams > 1) and generation_config.do_sample is True - # 8. prepare distribution pre_processing samplers + # 9. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( - repetition_penalty=repetition_penalty, - no_repeat_ngram_size=no_repeat_ngram_size, + generation_config=generation_config, input_ids_seq_length=input_ids_seq_length, - bad_words_ids=bad_words_ids, - min_length=min_length, - max_length=max_length, - eos_token_id=eos_token_id, - forced_bos_token_id=forced_bos_token_id, - forced_eos_token_id=forced_eos_token_id, - suppress_tokens=suppress_tokens, - begin_suppress_tokens=begin_suppress_tokens, - forced_decoder_ids=forced_decoder_ids, ) - # 9. go into different generation modes + # 10. go into different generation modes if is_greedy_gen_mode: - if num_return_sequences > 1: + if generation_config.num_return_sequences > 1: raise ValueError( - f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " greedy search." ) - # 10. run greedy search + # 11. run greedy search return self.greedy_search( input_ids, - max_length=max_length, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, + max_length=generation_config.max_length, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, logits_processor=logits_processor, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, **model_kwargs, ) elif is_contrastive_search_gen_mode: - if num_return_sequences > 1: + if generation_config.num_return_sequences > 1: raise ValueError( - f"num_return_sequences has to be 1, but is {num_return_sequences} when doing contrastive search." + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " contrastive search." ) - # 10. run contrastive search + # 11. 
run contrastive search return self.contrastive_search( input_ids, - top_k=top_k, - penalty_alpha=penalty_alpha, + top_k=generation_config.top_k, + penalty_alpha=generation_config.penalty_alpha, logits_processor=logits_processor, - max_length=max_length, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, + max_length=generation_config.max_length, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, **model_kwargs, ) elif is_sample_gen_mode: - # 10. prepare logits warper - logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature) + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config=generation_config) - # 11. expand input_ids with `num_return_sequences` additional sequences per batch + # 12. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, - expand_size=num_return_sequences, + expand_size=generation_config.num_return_sequences, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) - # 12. run sample + # 13. run sample return self.sample( input_ids, logits_processor=logits_processor, logits_warper=logits_warper, - max_length=max_length, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, + max_length=generation_config.max_length, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, seed=seed, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, **model_kwargs, ) elif is_beam_gen_mode: - if num_beams < num_return_sequences: + if generation_config.num_beams < generation_config.num_return_sequences: raise ValueError( - "Beam search decoding cannot return more sequences than it has beams. Please set " - f"num_beams >= num_return_sequences, got {num_beams} and {num_return_sequences} (respectivelly)" + "Beam search decoding cannot return more sequences than it has beams. Please set num_beams >=" + f" num_return_sequences, got {generation_config.num_beams} and" + f" {generation_config.num_return_sequences} (respectively)" ) - # 10. broadcast inputs to the desired number of beams - input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams) + # 11. broadcast inputs to the desired number of beams + input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams) if "encoder_outputs" in model_kwargs: model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams( - model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams + model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams ) if "attention_mask" in model_kwargs: model_kwargs["attention_mask"] = self._expand_to_num_beams( - model_kwargs["attention_mask"], num_beams=num_beams + model_kwargs["attention_mask"], num_beams=generation_config.num_beams ) - # 11. run beam search + # 12. 
run beam search return self.beam_search( input_ids, - max_length=max_length, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - length_penalty=length_penalty, - early_stopping=early_stopping, + max_length=generation_config.max_length, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + length_penalty=generation_config.length_penalty, + early_stopping=generation_config.early_stopping, logits_processor=logits_processor, - return_dict_in_generate=return_dict_in_generate, - num_return_sequences=num_return_sequences, + return_dict_in_generate=generation_config.return_dict_in_generate, + num_return_sequences=generation_config.num_return_sequences, **model_kwargs, ) elif is_beam_sample_gen_mode: - if num_beams < num_return_sequences: + if generation_config.num_beams < generation_config.num_return_sequences: raise ValueError( - "Beam search decoding cannot return more sequences than it has beams. Please set " - f"num_beams >= num_return_sequences, got {num_beams} and {num_return_sequences} (respectivelly)" + "Beam search decoding cannot return more sequences than it has beams. Please set num_beams >=" + f" num_return_sequences, got {generation_config.num_beams} and" + f" {generation_config.num_return_sequences} (respectively)" ) - # 10. prepare logits warper - logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature) + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config=generation_config) - # 11. broadcast inputs to the desired number of beams - input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams) + # 12. broadcast inputs to the desired number of beams + input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams) if "encoder_outputs" in model_kwargs: model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams( - model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams + model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams ) if "attention_mask" in model_kwargs: model_kwargs["attention_mask"] = self._expand_to_num_beams( - model_kwargs["attention_mask"], num_beams=num_beams + model_kwargs["attention_mask"], num_beams=generation_config.num_beams ) - # 12. run beam sample (beam search with sampling) + # 13. 
run beam sample (beam search with sampling) return self.beam_search( input_ids, do_sample=True, - max_length=max_length, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - length_penalty=length_penalty, - early_stopping=early_stopping, + max_length=generation_config.max_length, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + length_penalty=generation_config.length_penalty, + early_stopping=generation_config.early_stopping, logits_processor=logits_processor, logits_warper=logits_warper, - return_dict_in_generate=return_dict_in_generate, - num_return_sequences=num_return_sequences, + return_dict_in_generate=generation_config.return_dict_in_generate, + num_return_sequences=generation_config.num_return_sequences, **model_kwargs, ) @@ -1108,26 +1013,16 @@ class TFGenerationMixin: # retrieve decoder_start_token_id for encoder-decoder models # fall back to bos_token_id if necessary decoder_start_token_id = ( - decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id + decoder_start_token_id + if decoder_start_token_id is not None + else self.generation_config.decoder_start_token_id ) - bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id if decoder_start_token_id is not None: return decoder_start_token_id - elif ( - hasattr(self.config, "decoder") - and hasattr(self.config.decoder, "decoder_start_token_id") - and self.config.decoder.decoder_start_token_id is not None - ): - return self.config.decoder.decoder_start_token_id elif bos_token_id is not None: return bos_token_id - elif ( - hasattr(self.config, "decoder") - and hasattr(self.config.decoder, "bos_token_id") - and self.config.decoder.bos_token_id is not None - ): - return self.config.decoder.bos_token_id raise ValueError( "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." ) @@ -1332,46 +1227,30 @@ class TFGenerationMixin: def _get_logits_warper( self, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, + generation_config: GenerationConfig, ) -> TFLogitsProcessorList: """ This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsWarper`] instances used for multinomial sampling. 
""" - # init warp parameters - top_k = top_k if top_k is not None else self.config.top_k - top_p = top_p if top_p is not None else self.config.top_p - temperature = temperature if temperature is not None else self.config.temperature # instantiate warpers list warpers = TFLogitsProcessorList() # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` - if temperature is not None and temperature != 1.0: - warpers.append(TFTemperatureLogitsWarper(temperature)) - if top_k is not None and top_k != 0: - warpers.append(TFTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1)) - if top_p is not None and top_p < 1.0: - warpers.append(TFTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1)) + if generation_config.temperature is not None and generation_config.temperature != 1.0: + warpers.append(TFTemperatureLogitsWarper(generation_config.temperature)) + if generation_config.top_k is not None and generation_config.top_k != 0: + warpers.append(TFTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1)) + if generation_config.top_p is not None and generation_config.top_p < 1.0: + warpers.append(TFTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1)) return warpers def _get_logits_processor( self, - repetition_penalty: float, - no_repeat_ngram_size: int, + generation_config: GenerationConfig, input_ids_seq_length: int, - bad_words_ids: List[List[int]], - min_length: int, - max_length: int, - eos_token_id: int, - forced_bos_token_id: int, - forced_eos_token_id: int, - suppress_tokens: Optional[List[int]] = None, - begin_suppress_tokens: Optional[List[int]] = None, - forced_decoder_ids: Optional[List[List[int]]] = None, ) -> TFLogitsProcessorList: """ This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`] @@ -1379,42 +1258,45 @@ class TFGenerationMixin: """ processors = TFLogitsProcessorList() - repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty - no_repeat_ngram_size = ( - no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size - ) - bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - suppress_tokens = suppress_tokens if suppress_tokens is not None else self.config.suppress_tokens - begin_suppress_tokens = ( - begin_suppress_tokens if begin_suppress_tokens is not None else self.config.begin_suppress_tokens - ) - if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"): - forced_decoder_ids = self.config.forced_decoder_ids - # instantiate processors list - if repetition_penalty is not None and repetition_penalty != 1.0: - processors.append(TFRepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) - if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0: - processors.append(TFNoRepeatNGramLogitsProcessor(no_repeat_ngram_size)) - if bad_words_ids is not None: - processors.append(TFNoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)) - if min_length is not None and eos_token_id is not None and min_length > 0: - processors.append(TFMinLengthLogitsProcessor(min_length, eos_token_id)) - if forced_bos_token_id is not None: - processors.append(TFForcedBOSTokenLogitsProcessor(forced_bos_token_id)) - if forced_eos_token_id is not None: - 
processors.append(TFForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) - if suppress_tokens is not None: - processors.append(TFSuppressTokensLogitsProcessor(suppress_tokens)) - if begin_suppress_tokens is not None: + if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: + processors.append(TFRepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty)) + if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0: + processors.append(TFNoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size)) + if generation_config.bad_words_ids is not None: + processors.append( + TFNoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id) + ) + if ( + generation_config.min_length is not None + and generation_config.eos_token_id is not None + and generation_config.min_length > 0 + ): + processors.append(TFMinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)) + if generation_config.forced_bos_token_id is not None: + processors.append(TFForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id)) + if generation_config.forced_eos_token_id is not None: + processors.append( + TFForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id) + ) + if generation_config.suppress_tokens is not None: + processors.append(TFSuppressTokensLogitsProcessor(generation_config.suppress_tokens)) + if generation_config.begin_suppress_tokens is not None: begin_index = input_ids_seq_length - begin_index = begin_index if (input_ids_seq_length > 1 or forced_bos_token_id is None) else begin_index + 1 - if forced_decoder_ids is not None: - begin_index += forced_decoder_ids[-1][0] # generation starts after the last token that is forced - processors.append(TFSuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index)) - if forced_decoder_ids is not None: - processors.append(TFForceTokensLogitsProcessor(forced_decoder_ids)) + begin_index = ( + begin_index + if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None) + else begin_index + 1 + ) + if generation_config.forced_decoder_ids is not None: + begin_index += generation_config.forced_decoder_ids[-1][ + 0 + ] # generation starts after the last token that is forced + processors.append( + TFSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index) + ) + if generation_config.forced_decoder_ids is not None: + processors.append(TFForceTokensLogitsProcessor(generation_config.forced_decoder_ids)) return processors def greedy_search( @@ -1500,17 +1382,22 @@ class TFGenerationMixin: # 1. 
init greedy_search values logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList() - max_length = max_length if max_length is not None else self.config.max_length - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + max_length = max_length if max_length is not None else self.generation_config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate ) + use_cache = model_kwargs.pop("use_cache", self.generation_config.use_cache) use_xla = not tf.executing_eagerly() # TODO (Joao): fix cache format or find programatic way to detect cache index # GPT2 and other models has a slightly different cache structure, with a different batch axis @@ -1546,7 +1433,7 @@ class TFGenerationMixin: input_ids = generated[:, :cur_len] else: input_ids = tf.expand_dims(generated[:, cur_len - 1], -1) - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + model_inputs = self.prepare_inputs_for_generation(input_ids, use_cache=use_cache, **model_kwargs) # forward pass to get next token logits model_outputs = self( **model_inputs, @@ -1772,17 +1659,22 @@ class TFGenerationMixin: logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList() logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList() - max_length = max_length if max_length is not None else self.config.max_length - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + max_length = max_length if max_length is not None else self.generation_config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if 
output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate ) + use_cache = model_kwargs.pop("use_cache", self.generation_config.use_cache) use_xla = not tf.executing_eagerly() # TODO (Joao): fix cache format or find programatic way to detect cache index # GPT2 and other models has a slightly different cache structure, with a different batch axis @@ -1814,7 +1706,7 @@ class TFGenerationMixin: input_ids = generated[:, :cur_len] else: input_ids = tf.expand_dims(generated[:, cur_len - 1], -1) - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + model_inputs = self.prepare_inputs_for_generation(input_ids, use_cache=use_cache, **model_kwargs) # forward pass to get next token logits model_outputs = self( **model_inputs, @@ -2091,25 +1983,30 @@ class TFGenerationMixin: logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList() logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList() - max_length = max_length if max_length is not None else self.config.max_length - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + max_length = max_length if max_length is not None else self.generation_config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id num_return_sequences = ( - num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences + num_return_sequences if num_return_sequences is not None else self.generation_config.num_return_sequences ) - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) - output_scores = output_scores if output_scores is not None else self.config.output_scores + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate ) - length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty - early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping + length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty + early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping + use_cache = model_kwargs.pop("use_cache", self.generation_config.use_cache) use_xla = 
not tf.executing_eagerly() # TODO (Joao): fix cache format or find programatic way to detect cache index # GPT2 and other models has a slightly different cache structure, with a different batch axis @@ -2199,7 +2096,9 @@ class TFGenerationMixin: input_ids = running_sequences[:, :, :cur_len] else: input_ids = tf.expand_dims(running_sequences[:, :, cur_len - 1], -1) - model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), **model_kwargs) + model_inputs = self.prepare_inputs_for_generation( + flatten_beam_dim(input_ids), use_cache=use_cache, **model_kwargs + ) model_outputs = self( **model_inputs, return_dict=True, @@ -2521,17 +2420,22 @@ class TFGenerationMixin: # 1. init greedy_search values logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList() logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList() - max_length = max_length if max_length is not None else self.config.max_length - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + max_length = max_length if max_length is not None else self.generation_config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate ) + use_cache = True # In contrastive search, we always use cache use_xla = not tf.executing_eagerly() # TODO (Joao): fix cache format or find programatic way to detect cache index # GPT2 and other models has a slightly different cache structure, with a different batch axis @@ -2571,8 +2475,9 @@ class TFGenerationMixin: if model_kwargs.get("past") is None: # prepare inputs - model_inputs = self.prepare_inputs_for_generation(generated[:, :cur_len], **model_kwargs) - model_inputs["use_cache"] = True + model_inputs = self.prepare_inputs_for_generation( + generated[:, :cur_len], use_cache=use_cache, **model_kwargs + ) # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save # the `encoder_outputs` @@ -2662,8 +2567,9 @@ class TFGenerationMixin: ) # compute the candidate tokens by the language model and collects their hidden_states - next_model_inputs = self.prepare_inputs_for_generation(tf.reshape(top_k_ids, [-1, 1]), **model_kwargs) - next_model_inputs["use_cache"] = True + next_model_inputs = self.prepare_inputs_for_generation( + tf.reshape(top_k_ids, [-1, 1]), use_cache=use_cache, **model_kwargs + ) outputs = self( 
**next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions ) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 91779a38b4e..d24c59b0c95 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -39,7 +39,7 @@ from . import DataCollatorWithPadding, DefaultDataCollator from .activations_tf import get_tf_activation from .configuration_utils import PretrainedConfig from .dynamic_module_utils import custom_object_save -from .generation import TFGenerationMixin +from .generation import GenerationConfig, TFGenerationMixin from .tf_utils import shape_list from .utils import ( DUMMY_INPUTS, @@ -1137,6 +1137,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu # Save config and origin of the pretrained weights if given in model self.config = config self.name_or_path = config.name_or_path + self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None # Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec self._set_save_spec(self.serving.input_signature[0]) @@ -1200,6 +1201,18 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu """ raise NotImplementedError + def can_generate(self) -> bool: + """ + Returns whether this model can generate sequences with `.generate()`. + + Returns: + `bool`: Whether this model can generate sequences with `.generate()`. + """ + # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation + if "GenerationMixin" in str(self.prepare_inputs_for_generation): + return False + return True + def get_input_embeddings(self) -> tf.keras.layers.Layer: """ Returns the model's input embeddings layer. @@ -2832,6 +2845,29 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu " to use it for predictions and inference." ) + # If it is a model with generation capabilities, attempt to load the generation config + if model.can_generate(): + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + except OSError: + logger.info( + "Generation config file not found, using a generation config created from the model config." 
diff --git a/src/transformers/models/rag/modeling_tf_rag.py b/src/transformers/models/rag/modeling_tf_rag.py
index feed5f00d7f..cb3170b4e53 100644
--- a/src/transformers/models/rag/modeling_tf_rag.py
+++ b/src/transformers/models/rag/modeling_tf_rag.py
@@ -15,6 +15,7 @@
 """TFRAG model implementation."""
 
+import copy
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
@@ -999,25 +1000,9 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
         context_input_ids=None,
         context_attention_mask=None,
         doc_scores=None,
-        max_length=None,
-        min_length=None,
-        early_stopping=None,
-        use_cache=None,
-        num_beams=None,
-        bos_token_id=None,
-        pad_token_id=None,
-        eos_token_id=None,
-        length_penalty=None,
-        no_repeat_ngram_size=None,
-        bad_words_ids=None,
-        num_return_sequences=None,
-        decoder_start_token_id=None,
         n_docs=None,
-        output_scores=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict_in_generate=None,
-        **model_kwargs
+        generation_config=None,
+        **kwargs
     ):
         """
         Implements TFRAG token decoding.
@@ -1051,91 +1036,32 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
             If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the
             forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
-            max_length (`int`, *optional*, defaults to 20):
-                The maximum length of the sequence to be generated.
-            min_length (`int`, *optional*, defaults to 10):
-                The minimum length of the sequence to be generated.
-            early_stopping (`bool`, *optional*, defaults to `False`):
-                Whether or not to stop the beam search when at least `num_beams` sentences are finished per batch or
-                not.
-            use_cache: (`bool`, *optional*, defaults to `True`):
-                Whether or not the model should use the past last key/values attentions (if applicable to the model) to
-                speed up decoding.
-            pad_token_id (`int`, *optional*):
-                The id of the *padding* token.
-            bos_token_id (`int`, *optional*):
-                The id of the *beginning-of-sequence* token.
-            eos_token_id (`int`, *optional*):
-                The id of the *end-of-sequence* token.
-            length_penalty (`float`, *optional*, defaults to 1.0):
-                Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent
-                to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
-                the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
-                while `length_penalty` < 0.0 encourages shorter sequences.
-            no_repeat_ngram_size (`int`, *optional*, defaults to 0):
-                If set to int > 0, all ngrams of that size can only occur once.
-            bad_words_ids(`List[int]`, *optional*):
-                List of token ids that are not allowed to be generated. In order to get the tokens of the words that
-                should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
-            num_beams (`int`, *optional*, defaults to 1):
-                Number of beams for beam search. 1 means no beam search.
-            num_return_sequences(`int`, *optional*, defaults to 1):
-                The number of independently computed returned sequences for each element in the batch. Note that this
-                is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`] function, where
-                we set `num_return_sequences` to `num_beams`.
-            decoder_start_token_id (`int`, *optional*): If an
-                encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
             n_docs (`int`, *optional*, defaults to `config.n_docs`)
                 Number of documents to retrieve and/or number of documents for which to generate an answer.
-            output_attentions (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more details.
-            output_hidden_states (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more details.
-            output_scores (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
-            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-            model_specific_kwargs:
-                Additional model specific kwargs will be forwarded to the `forward` function of the model.
+            generation_config (`~generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to generate matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which has the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
+            kwargs:
+                Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
+                forwarded to the `forward` function of the model.
 
         Return:
             `tf.Tensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. The
             second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished
             early due to the `eos_token_id`.
""" + # Handle `generation_config` and kwargs that might update it + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + # set default parameters n_docs = n_docs if n_docs is not None else self.config.n_docs - max_length = max_length if max_length is not None else self.config.max_length - min_length = min_length if min_length is not None else self.config.min_length - early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping - use_cache = use_cache if use_cache is not None else self.config.use_cache - num_beams = num_beams if num_beams is not None else self.config.num_beams - bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id - pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id - length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty - no_repeat_ngram_size = ( - no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size - ) - bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids - num_return_sequences = ( - num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences - ) - decoder_start_token_id = ( - decoder_start_token_id - if decoder_start_token_id is not None - else self.config.generator.decoder_start_token_id - ) - - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate - ) # retrieve docs if self.retriever is not None and context_input_ids is None: @@ -1174,14 +1100,14 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss encoder_outputs = encoder( input_ids=context_input_ids, attention_mask=context_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, + output_attentions=generation_config.output_attentions, + output_hidden_states=generation_config.output_hidden_states, return_dict=True, ) decoder_input_ids = tf.fill( - (batch_size * num_beams, 1), - tf.cast(decoder_start_token_id, tf.int32), + (batch_size * generation_config.num_beams, 1), + tf.cast(generation_config.decoder_start_token_id, tf.int32), ) last_hidden_state = encoder_outputs["last_hidden_state"] @@ -1207,10 +1133,12 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss return tf.reshape(tensor, new_shape) # correctly extend last_hidden_state and attention mask - context_attention_mask = extend_enc_output(context_attention_mask, num_beams=num_beams) - encoder_outputs["last_hidden_state"] = extend_enc_output(last_hidden_state, num_beams=num_beams) + context_attention_mask = extend_enc_output(context_attention_mask, num_beams=generation_config.num_beams) + encoder_outputs["last_hidden_state"] = extend_enc_output( + last_hidden_state, num_beams=generation_config.num_beams + ) - doc_scores 
 
         # define start_len & additional parameters
         model_kwargs["doc_scores"] = doc_scores
@@ -1219,41 +1147,35 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
         model_kwargs["n_docs"] = n_docs
 
         pre_processor = self._get_logits_processor(
-            repetition_penalty=self.config.repetition_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            bad_words_ids=bad_words_ids,
-            min_length=min_length,
-            max_length=max_length,
-            eos_token_id=eos_token_id,
-            forced_bos_token_id=self.config.generator.forced_bos_token_id,
-            forced_eos_token_id=self.config.generator.forced_eos_token_id,
+            generation_config=generation_config,
             input_ids_seq_length=tf.shape(decoder_input_ids)[-1],
         )
 
-        if num_beams == 1:
+        if generation_config.num_beams == 1:
             return self.greedy_search(
                 input_ids=decoder_input_ids,
-                max_length=max_length,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
+                max_length=generation_config.max_length,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
                 logits_processor=pre_processor,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                output_scores=output_scores,
-                return_dict_in_generate=return_dict_in_generate,
+                output_attentions=generation_config.output_attentions,
+                output_hidden_states=generation_config.output_hidden_states,
+                output_scores=generation_config.output_scores,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
                 **model_kwargs,
             )
-        elif num_beams > 1:
-            if num_beams < num_return_sequences:
+        elif generation_config.num_beams > 1:
+            if generation_config.num_beams < generation_config.num_return_sequences:
                 raise ValueError(
-                    "Beam search decoding cannot return more sequences than it has beams. Please set "
-                    f"num_beams >= num_return_sequences, got {num_beams} and {num_return_sequences} (respectivelly)"
+                    "Beam search decoding cannot return more sequences than it has beams. Please set num_beams >="
+                    f" num_return_sequences, got {generation_config.num_beams} and"
+                    f" {generation_config.num_return_sequences} (respectively)"
                 )
 
             def unflatten_beam_dim(tensor):
                 """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
                 shape = shape_list(tensor)
-                return tf.reshape(tensor, [-1, num_beams] + shape[1:])
+                return tf.reshape(tensor, [-1, generation_config.num_beams] + shape[1:])
 
             decoder_input_ids = unflatten_beam_dim(decoder_input_ids)
             model_kwargs["attention_mask"] = unflatten_beam_dim(model_kwargs["attention_mask"])
@@ -1263,18 +1185,20 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
 
             return self.beam_search(
                 input_ids=decoder_input_ids,
-                max_length=max_length,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
+                max_length=generation_config.max_length,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
                 logits_processor=pre_processor,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                output_scores=output_scores,
-                return_dict_in_generate=return_dict_in_generate,
+                output_attentions=generation_config.output_attentions,
+                output_hidden_states=generation_config.output_hidden_states,
+                output_scores=generation_config.output_scores,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
                 **model_kwargs,
             )
         else:
-            raise ValueError(f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {num_beams}")
+            raise ValueError(
+                f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {generation_config.num_beams}"
+            )
 
     def get_input_embeddings(self):
         return self.rag.generator.get_input_embeddings()
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 4dcc14d8070..f8ca8506262 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -1824,18 +1824,18 @@ class TFModelTesterMixin:
             model.train_on_batch(test_batch, test_batch_labels)
 
     def _test_xla_generate(self, **generate_kwargs):
-        def _generate_and_check_results(model, config, inputs_dict):
+        def _generate_and_check_results(model, inputs_dict):
             if "input_ids" in inputs_dict:
                 inputs = inputs_dict["input_ids"]
                 # make sure there are no pad tokens in prompt, which may trigger unwanted behavior
-                if config.pad_token_id is not None:
+                if model.generation_config.pad_token_id is not None:
                     if config.pad_token_id == 0:
-                        new_pad_token = config.pad_token_id + 1
+                        new_pad_token = model.generation_config.pad_token_id + 1
                     else:
-                        new_pad_token = config.pad_token_id - 1
+                        new_pad_token = model.generation_config.pad_token_id - 1
                 else:
                     new_pad_token = None
-                inputs = tf.where(inputs != config.pad_token_id, inputs, new_pad_token)
+                inputs = tf.where(inputs != model.generation_config.pad_token_id, inputs, new_pad_token)
             elif "input_features" in inputs_dict:
                 inputs = inputs_dict["input_features"]
             else:
@@ -1854,10 +1854,10 @@ class TFModelTesterMixin:
             model = model_class(config)
 
             if model.supports_xla_generation:
-                _generate_and_check_results(model, config, inputs_dict)
+                _generate_and_check_results(model, inputs_dict)
             else:
                 with self.assertRaises(ValueError):
-                    _generate_and_check_results(model, config, inputs_dict)
+                    _generate_and_check_results(model, inputs_dict)
 
     def test_xla_generate_fast(self):
         """
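
Editor's note: the pattern introduced in `TFRagTokenForGeneration.generate` above (deep-copy the model's default `GenerationConfig`, then fold loose kwargs into the copy) is the same one the main `generate` now relies on. A small, self-contained sketch of that behaviour, assuming, as the diff's comment states, that `GenerationConfig.update` returns the kwargs it could not consume; the concrete values are illustrative only:

import copy

from transformers import GenerationConfig

# stand-in for `self.generation_config` on the model
default_config = GenerationConfig(num_beams=4, max_length=50)

# never mutate the model's defaults: work on a copy, exactly as the diff does
generation_config = copy.deepcopy(default_config)

# attributes that exist on GenerationConfig are overridden; anything else is
# handed back so it can be forwarded to the model as a model kwarg (e.g. `n_docs`)
model_kwargs = generation_config.update(num_beams=2, n_docs=4)

print(generation_config.num_beams)  # 2  -> overridden on the copy
print(default_config.num_beams)     # 4  -> the model default is untouched
print(model_kwargs)                 # {'n_docs': 4}
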