Generate: use GenerationConfig as the basis for .generate() parametrization (#20388)

* generate from config mvp

* fix failing tests

* max_time test

* Load default gen config at model load time; Update docs

* further documentation; add tests

* adapt rag to the new structure

* handle models not instantiated with from_pretained (like in tests)

* better default generation config

* add can_generate fn

* handle legacy use case of ad hoc model config changes

* initialize gen config from config in individual methods, if gen config is none

* fix _get_decoder_start_token_id when called outside GenerationMixin

* correct model config load order (set attr > model config > decoder config)

* update rag to match latest changes

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* load gen config from model config in model.from_pretrained

* fix can_generate fn

* handle generate calls without a previous from_pretrained (e.g. tests)

* add legacy behavior (and a warning)

* lower logger severity

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Joao Gante 2022-12-15 18:27:20 +00:00 committed by GitHub
parent b1706f6908
commit 4bc723f87d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 694 additions and 733 deletions

View File

@ -18,12 +18,79 @@ Each framework has a generate method for auto-regressive text generation impleme
- TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`]. - TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`].
- Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`]. - Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`].
<!--- TODO: add a brief description of GenerationConfig (with examples) when it becomes usable with generate ---> Regardless of your framework of choice, you can parameterize the generate method with a [`~generation.GenerationConfig`]
class instance. Please refer to this class for the complete list of generation parameters, which control the behavior
of the generation method.
All models have a default generation configuration that will be used if you don't provide one. If you have a loaded
model instance `model`, you can inspect the default generation configuration with `model.generation_config`. If you'd
like to set a new default generation configuration, you can create a new [`~generation.GenerationConfig`] instance and
store it with `save_pretrained`, making sure to leave its `config_file_name` argument empty.
```python
from transformers import AutoModelForCausalLM, GenerationConfig
model = AutoModelForCausalLM.from_pretrained("my_account/my_model")
# Inspect the default generation configuration
print(model.generation_config)
# Set a new default generation configuration
generation_config = GenerationConfig(
max_new_tokens=50, do_sample=True, top_k=50, eos_token_id=model.config.eos_token_id
)
generation_config.save_pretrained("my_account/my_model", push_to_hub=True)
```
<Tip>
If you inspect a serialized [`~generation.GenerationConfig`] file or print a class instance, you will notice that
default values are omitted. Some attributes, like `max_length`, have a conservative default value, to avoid running
into resource limitations. Make sure you double-check the defaults in the documentation.
</Tip>
You can also store several generation parametrizations in a single directory, making use of the `config_file_name`
argument in `save_pretrained`. You can latter instantiate them with `from_pretrained`. This is useful if you want to
store several generation configurations for a single model (e.g. one for creative text generation with sampling, and
other for summarization with beam search).
```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
translation_generation_config = GenerationConfig(
num_beams=4,
early_stopping=True,
decoder_start_token_id=0,
eos_token_id=model.config.eos_token_id,
pad_token=model.config.pad_token_id,
)
# If you were working on a model for which your had the right Hub permissions, you could store a named generation
# config as follows
translation_generation_config.save_pretrained("t5-small", "translation_generation_config.json", push_to_hub=True)
# You could then use the named generation config file to parameterize generation
generation_config = GenerationConfig.from_pretrained("t5-small", "translation_generation_config.json")
inputs = tokenizer("translate English to French: Configuration files are easy to use!", return_tensors="pt")
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
# ['Les fichiers de configuration sont faciles à utiliser !']
```
Finally, you can specify ad hoc modifications to the used generation configuration by passing the attribute you
wish to override directly to the generate method (e.g. `model.generate(inputs, max_new_tokens=512)`). Each
framework's `generate` method docstring (available below) has a few illustrative examples on the different strategies
to parameterize it.
## GenerationConfig ## GenerationConfig
[[autodoc]] generation.GenerationConfig [[autodoc]] generation.GenerationConfig
- from_pretrained - from_pretrained
- from_model_config
- save_pretrained - save_pretrained
## GenerationMixin ## GenerationMixin

View File

@ -20,6 +20,7 @@ import os
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
from .. import __version__ from .. import __version__
from ..configuration_utils import PretrainedConfig
from ..utils import ( from ..utils import (
GENERATION_CONFIG_NAME, GENERATION_CONFIG_NAME,
PushToHubMixin, PushToHubMixin,
@ -36,7 +37,23 @@ logger = logging.get_logger(__name__)
class GenerationConfig(PushToHubMixin): class GenerationConfig(PushToHubMixin):
r""" r"""
Class that holds a configuration for a generation task. Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
- *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
`do_sample=False`.
- *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.`
and `top_k>1`
- *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
`do_sample=True`.
- *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
`do_sample=False`.
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if
`num_beams>1` and `do_sample=True`.
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if
`num_beams>1` and `num_beam_groups>1`.
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
`constraints!=None` or `force_words_ids!=None`.
<Tip> <Tip>
@ -45,6 +62,9 @@ class GenerationConfig(PushToHubMixin):
</Tip> </Tip>
Most of these parameters are explained in more detail in [this blog
post](https://huggingface.co/blog/how-to-generate).
Arg: Arg:
> Parameters that control the length of the output > Parameters that control the length of the output
@ -73,6 +93,9 @@ class GenerationConfig(PushToHubMixin):
[this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
penalty_alpha (`float`, *optional*): penalty_alpha (`float`, *optional*):
The values balance the model confidence and the degeneration penalty in contrastive search decoding. The values balance the model confidence and the degeneration penalty in contrastive search decoding.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
> Parameters for manipulation of the model output logits > Parameters for manipulation of the model output logits
@ -108,13 +131,13 @@ class GenerationConfig(PushToHubMixin):
words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this
triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
can allow different forms of each word. can allow different forms of each word.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
renormalize_logits (`bool`, *optional*, defaults to `False`): renormalize_logits (`bool`, *optional*, defaults to `False`):
Whether to renormalize the logits after applying all the logits processors or warpers (including the custom Whether to renormalize the logits after applying all the logits processors or warpers (including the custom
ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits
are normalized but some logit processors or warpers break the normalization. are normalized but some logit processors or warpers break the normalization.
constraints (`List[Constraint]`, *optional*):
Custom constraints that can be added to the generation to ensure that the output will contain the use of
certain tokens as defined by `Constraint` objects, in the most sensible way possible.
forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`): forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`):
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
@ -191,6 +214,7 @@ class GenerationConfig(PushToHubMixin):
self.num_beams = kwargs.pop("num_beams", 1) self.num_beams = kwargs.pop("num_beams", 1)
self.num_beam_groups = kwargs.pop("num_beam_groups", 1) self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
self.penalty_alpha = kwargs.pop("penalty_alpha", None) self.penalty_alpha = kwargs.pop("penalty_alpha", None)
self.use_cache = kwargs.pop("use_cache", True)
# Parameters for manipulation of the model output logits # Parameters for manipulation of the model output logits
self.temperature = kwargs.pop("temperature", 1.0) self.temperature = kwargs.pop("temperature", 1.0)
@ -202,7 +226,9 @@ class GenerationConfig(PushToHubMixin):
self.length_penalty = kwargs.pop("length_penalty", 1.0) self.length_penalty = kwargs.pop("length_penalty", 1.0)
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
self.bad_words_ids = kwargs.pop("bad_words_ids", None) self.bad_words_ids = kwargs.pop("bad_words_ids", None)
self.force_word_ids = kwargs.pop("force_word_ids", None) self.force_words_ids = kwargs.pop("force_words_ids", None)
self.renormalize_logits = kwargs.pop("renormalize_logits", False)
self.constraints = kwargs.pop("constraints", None)
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None) self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None) self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
self.remove_invalid_values = kwargs.pop("remove_invalid_values", False) self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
@ -230,12 +256,20 @@ class GenerationConfig(PushToHubMixin):
# Wild card # Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {}) self.generation_kwargs = kwargs.pop("generation_kwargs", {})
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the the hub interface. # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the the hub
# interface.
self._from_model_config = kwargs.pop("_from_model_config", False)
self._commit_hash = kwargs.pop("_commit_hash", None) self._commit_hash = kwargs.pop("_commit_hash", None)
self.transformers_version = kwargs.pop("transformers_version", __version__) self.transformers_version = kwargs.pop("transformers_version", __version__)
def __eq__(self, other): def __eq__(self, other):
return self.__dict__ == other.__dict__ self_dict = self.__dict__.copy()
other_dict = other.__dict__.copy()
# ignore metadata
for metadata_field in ("_from_model_config", "_commit_hash", "transformers_version"):
self_dict.pop(metadata_field, None)
other_dict.pop(metadata_field, None)
return self_dict == other_dict
def __repr__(self): def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}" return f"{self.__class__.__name__} {self.to_json_string()}"
@ -484,18 +518,11 @@ class GenerationConfig(PushToHubMixin):
kwargs["_commit_hash"] = config_dict["_commit_hash"] kwargs["_commit_hash"] = config_dict["_commit_hash"]
config = cls(**config_dict) config = cls(**config_dict)
unused_kwargs = config.update(**kwargs)
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
logger.info(f"Generate config {config}") logger.info(f"Generate config {config}")
if return_unused_kwargs: if return_unused_kwargs:
return config, kwargs return config, unused_kwargs
else: else:
return config return config
@ -568,3 +595,54 @@ class GenerationConfig(PushToHubMixin):
""" """
with open(json_file_path, "w", encoding="utf-8") as writer: with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string(use_diff=use_diff)) writer.write(self.to_json_string(use_diff=use_diff))
@classmethod
def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig":
"""
Instantiates a [`GenerationConfig`] from a [`PretrainedConfig`]. This function is useful to convert legacy
[`PretrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`].
Args:
model_config (`PretrainedConfig`):
The model config that will be used to instantiate the generation config.
Returns:
[`GenerationConfig`]: The configuration object instantiated from those parameters.
"""
config_dict = model_config.to_dict()
config = cls.from_dict(config_dict, return_unused_kwargs=False)
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
# generation config.
for decoder_name in ("decoder", "generator"):
if decoder_name in config_dict:
default_generation_config = GenerationConfig()
decoder_config = config_dict[decoder_name]
for attr in config.to_dict().keys():
if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
setattr(config, attr, decoder_config[attr])
config._from_model_config = True
return config
def update(self, **kwargs):
"""
Updates attributes of this class instance with attributes from `kwargs` if they match existing atributtes,
returning all the unused kwargs.
Args:
kwargs (`Dict[str, Any]`):
Dictionary of attributes to tentatively update this class.
Returns:
`Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
"""
to_remove = []
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
to_remove.append(key)
# remove all the attributes that were updated, without modifying the input dict
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
return unused_kwargs

File diff suppressed because it is too large Load Diff

View File

@ -39,7 +39,7 @@ from .activations import get_activation
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from .dynamic_module_utils import custom_object_save from .dynamic_module_utils import custom_object_save
from .generation import GenerationMixin from .generation import GenerationConfig, GenerationMixin
from .pytorch_utils import ( # noqa: F401 from .pytorch_utils import ( # noqa: F401
Conv1D, Conv1D,
apply_chunking_to_forward, apply_chunking_to_forward,
@ -1024,6 +1024,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
self.config = config self.config = config
self.name_or_path = config.name_or_path self.name_or_path = config.name_or_path
self.warnings_issued = {} self.warnings_issued = {}
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
def post_init(self): def post_init(self):
""" """
@ -1106,6 +1107,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
""" """
return getattr(self, self.base_model_prefix, self) return getattr(self, self.base_model_prefix, self)
def can_generate(self) -> bool:
"""
Returns whether this model can generate sequences with `.generate()`.
Returns:
`bool`: Whether this model can generate sequences with `.generate()`.
"""
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
if "GenerationMixin" in str(self.prepare_inputs_for_generation):
return False
return True
def get_input_embeddings(self) -> nn.Module: def get_input_embeddings(self) -> nn.Module:
""" """
Returns the model's input embeddings. Returns the model's input embeddings.
@ -2477,6 +2490,29 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Set model in evaluation mode to deactivate DropOut modules by default # Set model in evaluation mode to deactivate DropOut modules by default
model.eval() model.eval()
# If it is a model with generation capabilities, attempt to load the generation config
if model.can_generate():
try:
model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
**kwargs,
)
except OSError:
logger.info(
"Generation config file not found, using a generation config created from the model config."
)
pass
# Dispatch model with hooks on all devices if necessary # Dispatch model with hooks on all devices if necessary
if device_map is not None: if device_map is not None:
dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index) dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index)

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
"""RAG model implementation.""" """RAG model implementation."""
import copy
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union
@ -21,7 +22,7 @@ import torch
from torch import nn from torch import nn
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...generation import BeamSearchScorer, LogitsProcessorList, StoppingCriteriaList from ...generation import BeamSearchScorer, GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from ...modeling_outputs import ModelOutput from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings
@ -1384,33 +1385,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
context_input_ids: Optional[torch.LongTensor] = None, context_input_ids: Optional[torch.LongTensor] = None,
context_attention_mask: Optional[torch.LongTensor] = None, context_attention_mask: Optional[torch.LongTensor] = None,
doc_scores: Optional[torch.FloatTensor] = None, doc_scores: Optional[torch.FloatTensor] = None,
max_length: Optional[int] = None,
min_length: Optional[int] = None,
early_stopping: Optional[bool] = None,
use_cache: Optional[bool] = None,
num_beams: Optional[int] = None,
num_beam_groups: Optional[int] = None,
diversity_penalty: Optional[float] = None,
bos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
encoder_no_repeat_ngram_size: Optional[int] = None,
repetition_penalty: Optional[float] = None,
bad_words_ids: Optional[List[List[int]]] = None,
num_return_sequences: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
n_docs: Optional[int] = None, n_docs: Optional[int] = None,
generation_config: Optional[GenerationConfig] = None,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
renormalize_logits: Optional[bool] = None,
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
forced_bos_token_id: Optional[int] = None, **kwargs
forced_eos_token_id: Optional[int] = None,
remove_invalid_values: Optional[bool] = None,
exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
**model_kwargs
) -> torch.LongTensor: ) -> torch.LongTensor:
""" """
Implements RAG token decoding. Implements RAG token decoding.
@ -1444,51 +1424,15 @@ class RagTokenForGeneration(RagPreTrainedModel):
If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the
forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`]. forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
max_length (`int`, *optional*, defaults to 20):
The maximum length of the sequence to be generated.
min_length (`int`, *optional*, defaults to 10):
The minimum length of the sequence to be generated.
early_stopping (`bool`, *optional*, defaults to `False`):
Whether or not to stop the beam search when at least `num_beams` sentences are finished per batch or
not.
use_cache: (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
bos_token_id (`int`, *optional*):
The id of the *beginning-of-sequence* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
length_penalty (`float`, *optional*, defaults to 1.0):
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent
to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
while `length_penalty` < 0.0 encourages shorter sequences.
no_repeat_ngram_size (`int`, *optional*, defaults to 0):
If set to int > 0, all ngrams of that size can only occur once.
encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
`decoder_input_ids`.
bad_words_ids(`List[int]`, *optional*):
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
num_beams (`int`, *optional*, defaults to 1):
Number of beams for beam search. 1 means no beam search.
num_beam_groups (`int`, *optional*, defaults to 1):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of
beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
diversity_penalty (`float`, *optional*, defaults to 0.0):
This value is subtracted from a beam's score if it generates a token same as any beam from other group
at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is
enabled.
num_return_sequences(`int`, *optional*, defaults to 1):
The number of independently computed returned sequences for each element in the batch. Note that this
is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`] function, where
we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an
encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
n_docs (`int`, *optional*, defaults to `config.n_docs`) n_docs (`int`, *optional*, defaults to `config.n_docs`)
Number of documents to retrieve and/or number of documents for which to generate an answer. Number of documents to retrieve and/or number of documents for which to generate an answer.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which has the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
prefix_allowed_tokens_fn: (`Callable[[int, torch.Tensor], List[int]]`, *optional*): prefix_allowed_tokens_fn: (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
If provided, this function constraints the beam search to allowed tokens only at each step. If not If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments `inputs_ids` and the batch ID provided no constraint is applied. This function takes 2 arguments `inputs_ids` and the batch ID
@ -1497,53 +1441,30 @@ class RagTokenForGeneration(RagPreTrainedModel):
constrained generation conditioned on the prefix, as described in [Autoregressive Entity constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904). Retrieval](https://arxiv.org/abs/2010.00904).
logits_processor (`LogitsProcessorList`, *optional*): logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and a Custom logits processors that complement the default logits processors built from arguments and a
model's config. If a logit processor is passed that is already created with the arguments or a model's model's config. If a logit processor is passed that is already created with the arguments or a model's
config an error is thrown. config an error is thrown.
stopping_criteria (`StoppingCriteriaList`, *optional*): stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a Custom stopping criteria that complement the default stopping criteria built from arguments and a
model's config. If a stopping criteria is passed that is already created with the arguments or a model's config. If a stopping criteria is passed that is already created with the arguments or a
model's config an error is thrown. model's config an error is thrown.
forced_bos_token_id (`int`, *optional*): kwargs:
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be forwarded to the `forward` function of the model.
the target language token.
forced_eos_token_id (`int`, *optional*):
The id of the token to force as the last generated token when `max_length` is reached.
remove_invalid_values (`bool`, *optional*):
Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to
crash. Note that using `remove_invalid_values` can slow down generation.
Return: Return:
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated
sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches
finished early due to the `eos_token_id`. finished early due to the `eos_token_id`.
""" """
# Handle `generation_config` and kwargs that might update it
if generation_config is None:
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
# set default parameters # set default parameters
n_docs = n_docs if n_docs is not None else self.config.n_docs n_docs = n_docs if n_docs is not None else self.config.n_docs
num_beams = num_beams if num_beams is not None else self.config.num_beams
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
max_length = max_length if max_length is not None else self.config.max_length
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id
use_cache = use_cache if use_cache is not None else self.config.use_cache
decoder_start_token_id = (
decoder_start_token_id
if decoder_start_token_id is not None
else self.config.generator.decoder_start_token_id
)
remove_invalid_values = (
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
)
exponential_decay_length_penalty = (
exponential_decay_length_penalty
if exponential_decay_length_penalty is not None
else self.config.exponential_decay_length_penalty
)
# retrieve docs # retrieve docs
if self.retriever is not None and context_input_ids is None: if self.retriever is not None and context_input_ids is None:
@ -1583,8 +1504,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True) encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)
input_ids = torch.full( input_ids = torch.full(
(batch_size * num_beams, 1), (batch_size * generation_config.num_beams, 1),
decoder_start_token_id, generation_config.decoder_start_token_id,
dtype=torch.long, dtype=torch.long,
device=next(self.parameters()).device, device=next(self.parameters()).device,
) )
@ -1600,10 +1521,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
return tensor.reshape((batch_size * num_beams * n_docs,) + tensor.shape[3:]) return tensor.reshape((batch_size * num_beams * n_docs,) + tensor.shape[3:])
# correctly extend last_hidden_state and attention mask # correctly extend last_hidden_state and attention mask
context_attention_mask = extend_enc_output(context_attention_mask, num_beams=num_beams) context_attention_mask = extend_enc_output(context_attention_mask, num_beams=generation_config.num_beams)
encoder_outputs["last_hidden_state"] = extend_enc_output(last_hidden_state, num_beams=num_beams) encoder_outputs["last_hidden_state"] = extend_enc_output(
last_hidden_state, num_beams=generation_config.num_beams
)
doc_scores = doc_scores.repeat_interleave(num_beams, dim=0) doc_scores = doc_scores.repeat_interleave(generation_config.num_beams, dim=0)
# define start_len & additional parameters # define start_len & additional parameters
model_kwargs["doc_scores"] = doc_scores model_kwargs["doc_scores"] = doc_scores
@ -1612,64 +1535,51 @@ class RagTokenForGeneration(RagPreTrainedModel):
model_kwargs["n_docs"] = n_docs model_kwargs["n_docs"] = n_docs
pre_processor = self._get_logits_processor( pre_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty, generation_config=generation_config,
no_repeat_ngram_size=no_repeat_ngram_size,
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
input_ids_seq_length=input_ids_seq_length, input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=context_input_ids, encoder_input_ids=context_input_ids,
bad_words_ids=bad_words_ids,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
num_beams=num_beams,
num_beam_groups=num_beam_groups,
diversity_penalty=diversity_penalty,
remove_invalid_values=remove_invalid_values,
exponential_decay_length_penalty=exponential_decay_length_penalty,
logits_processor=logits_processor, logits_processor=logits_processor,
renormalize_logits=renormalize_logits,
) )
if num_beams == 1: if generation_config.num_beams == 1:
if num_return_sequences > 1: if generation_config.num_return_sequences > 1:
raise ValueError( raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
" greedy search."
) )
return self.greedy_search( return self.greedy_search(
input_ids, input_ids,
logits_processor=pre_processor, logits_processor=pre_processor,
max_length=max_length, max_length=generation_config.max_length,
pad_token_id=pad_token_id, pad_token_id=generation_config.pad_token_id,
eos_token_id=eos_token_id, eos_token_id=generation_config.eos_token_id,
**model_kwargs, **model_kwargs,
) )
elif num_beams > 1: elif generation_config.num_beams > 1:
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty if generation_config.num_return_sequences > generation_config.num_beams:
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
if num_return_sequences > num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
beam_scorer = BeamSearchScorer( beam_scorer = BeamSearchScorer(
batch_size=batch_size, batch_size=batch_size,
num_beams=num_beams, num_beams=generation_config.num_beams,
device=self.device, device=self.device,
length_penalty=length_penalty, length_penalty=generation_config.length_penalty,
do_early_stopping=early_stopping, do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=num_return_sequences, num_beam_hyps_to_keep=generation_config.num_return_sequences,
) )
return self.beam_search( return self.beam_search(
input_ids, input_ids,
beam_scorer, beam_scorer,
logits_processor=pre_processor, logits_processor=pre_processor,
max_length=max_length, max_length=generation_config.max_length,
pad_token_id=pad_token_id, pad_token_id=generation_config.pad_token_id,
eos_token_id=eos_token_id, eos_token_id=generation_config.eos_token_id,
**model_kwargs, **model_kwargs,
) )
else: else:
raise ValueError(f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {num_beams}") raise ValueError(
f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {generation_config.num_beams}"
)
def get_input_embeddings(self): def get_input_embeddings(self):
return self.rag.generator.get_input_embeddings() return self.rag.generator.get_input_embeddings()

View File

@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import tempfile import tempfile
import unittest import unittest
from parameterized import parameterized from parameterized import parameterized
from transformers.generation import GenerationConfig from transformers import AutoConfig, GenerationConfig
class LogitsProcessorTest(unittest.TestCase): class LogitsProcessorTest(unittest.TestCase):
@ -43,3 +44,33 @@ class LogitsProcessorTest(unittest.TestCase):
self.assertEqual(loaded_config.top_k, 50) self.assertEqual(loaded_config.top_k, 50)
self.assertEqual(loaded_config.max_length, 20) self.assertEqual(loaded_config.max_length, 20)
self.assertEqual(loaded_config.max_time, None) self.assertEqual(loaded_config.max_time, None)
def test_from_model_config(self):
model_config = AutoConfig.from_pretrained("gpt2")
generation_config_from_model = GenerationConfig.from_model_config(model_config)
default_generation_config = GenerationConfig()
# The generation config has loaded a few non-default parameters from the model config
self.assertNotEqual(generation_config_from_model, default_generation_config)
# One of those parameters is eos_token_id -- check if it matches
self.assertNotEqual(generation_config_from_model.eos_token_id, default_generation_config.eos_token_id)
self.assertEqual(generation_config_from_model.eos_token_id, model_config.eos_token_id)
def test_update(self):
generation_config = GenerationConfig()
update_kwargs = {
"max_new_tokens": 1024,
"foo": "bar",
}
update_kwargs_copy = copy.deepcopy(update_kwargs)
unused_kwargs = generation_config.update(**update_kwargs)
# update_kwargs was not modified (no side effects)
self.assertEqual(update_kwargs, update_kwargs_copy)
# update_kwargs was used to update the config on valid attributes
self.assertEqual(generation_config.max_new_tokens, 1024)
# `.update()` returns a dictionary of unused kwargs
self.assertEqual(unused_kwargs, {"foo": "bar"})