Generate: use GenerationConfig as the basis for `.generate()` parametrization (#20388)
* generate from config mvp
* fix failing tests
* max_time test
* Load default gen config at model load time; Update docs
* further documentation; add tests
* adapt rag to the new structure
* handle models not instantiated with from_pretrained (like in tests)
* better default generation config
* add can_generate fn
* handle legacy use case of ad hoc model config changes
* initialize gen config from config in individual methods, if gen config is none
* fix _get_decoder_start_token_id when called outside GenerationMixin
* correct model config load order (set attr > model config > decoder config)
* update rag to match latest changes
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* load gen config from model config in model.from_pretrained
* fix can_generate fn
* handle generate calls without a previous from_pretrained (e.g. tests)
* add legacy behavior (and a warning)
* lower logger severity

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent b1706f6908
commit 4bc723f87d
@@ -18,12 +18,79 @@ Each framework has a generate method for auto-regressive text generation impleme
- TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`].
- Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`].

<!--- TODO: add a brief description of GenerationConfig (with examples) when it becomes usable with generate --->
Regardless of your framework of choice, you can parameterize the generate method with a [`~generation.GenerationConfig`]
class instance. Please refer to this class for the complete list of generation parameters, which control the behavior
of the generation method.

All models have a default generation configuration that will be used if you don't provide one. If you have a loaded
model instance `model`, you can inspect the default generation configuration with `model.generation_config`. If you'd
like to set a new default generation configuration, you can create a new [`~generation.GenerationConfig`] instance and
store it with `save_pretrained`, making sure to leave its `config_file_name` argument empty.

```python
from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("my_account/my_model")

# Inspect the default generation configuration
print(model.generation_config)

# Set a new default generation configuration
generation_config = GenerationConfig(
    max_new_tokens=50, do_sample=True, top_k=50, eos_token_id=model.config.eos_token_id
)
generation_config.save_pretrained("my_account/my_model", push_to_hub=True)
```

<Tip>

If you inspect a serialized [`~generation.GenerationConfig`] file or print a class instance, you will notice that
default values are omitted. Some attributes, like `max_length`, have a conservative default value, to avoid running
into resource limitations. Make sure you double-check the defaults in the documentation.

</Tip>

You can also store several generation parametrizations in a single directory, making use of the `config_file_name`
argument in `save_pretrained`. You can later instantiate them with `from_pretrained`. This is useful if you want to
store several generation configurations for a single model (e.g. one for creative text generation with sampling, and
another for summarization with beam search).

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

translation_generation_config = GenerationConfig(
    num_beams=4,
    early_stopping=True,
    decoder_start_token_id=0,
    eos_token_id=model.config.eos_token_id,
    pad_token_id=model.config.pad_token_id,
)
# If you were working on a model for which you had the right Hub permissions, you could store a named generation
# config as follows
translation_generation_config.save_pretrained("t5-small", "translation_generation_config.json", push_to_hub=True)

# You could then use the named generation config file to parameterize generation
generation_config = GenerationConfig.from_pretrained("t5-small", "translation_generation_config.json")
inputs = tokenizer("translate English to French: Configuration files are easy to use!", return_tensors="pt")
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
# ['Les fichiers de configuration sont faciles à utiliser !']
```

Finally, you can specify ad hoc modifications to the used generation configuration by passing the attribute you
wish to override directly to the generate method (e.g. `model.generate(inputs, max_new_tokens=512)`). Each
framework's `generate` method docstring (available below) has a few illustrative examples on the different strategies
to parameterize it.
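To make the ad hoc override path described above concrete, here is a minimal sketch (not from the diff; `gpt2` is used purely as a stand-in checkpoint):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Configuration files are easy to use!", return_tensors="pt")

# Uses the loaded default generation configuration (model.generation_config)
outputs = model.generate(**inputs)

# Ad hoc keyword arguments override only the attributes they name, leaving the rest of the defaults intact
outputs = model.generate(**inputs, max_new_tokens=32, do_sample=True, top_k=50)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```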
## GenerationConfig

[[autodoc]] generation.GenerationConfig
    - from_pretrained
    - from_model_config
    - save_pretrained

## GenerationMixin
@@ -20,6 +20,7 @@ import os
from typing import Any, Dict, Optional, Union

from .. import __version__
from ..configuration_utils import PretrainedConfig
from ..utils import (
    GENERATION_CONFIG_NAME,
    PushToHubMixin,
@@ -36,7 +37,23 @@ logger = logging.get_logger(__name__)

class GenerationConfig(PushToHubMixin):
    r"""
    Class that holds a configuration for a generation task.
    Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
    for text-decoder, text-to-text, speech-to-text, and vision-to-text models:

        - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
          `do_sample=False`.
        - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.`
          and `top_k>1`
        - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
          `do_sample=True`.
        - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
          `do_sample=False`.
        - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if
          `num_beams>1` and `do_sample=True`.
        - *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if
          `num_beams>1` and `num_beam_groups>1`.
        - *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
          `constraints!=None` or `force_words_ids!=None`.

    <Tip>
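The configuration-to-strategy mapping listed in the docstring above can be sketched as follows; this is illustrative only, with checkpoint and prompt as arbitrary stand-ins:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# num_beams=1 and do_sample=False -> greedy decoding
greedy = GenerationConfig(max_new_tokens=20)
# num_beams=1 and do_sample=True -> multinomial sampling
sampling = GenerationConfig(max_new_tokens=20, do_sample=True, top_k=50)
# num_beams>1 and do_sample=False -> beam-search decoding
beam = GenerationConfig(max_new_tokens=20, num_beams=4)

for config in (greedy, sampling, beam):
    outputs = model.generate(**inputs, generation_config=config)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```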
@@ -45,6 +62,9 @@ class GenerationConfig(PushToHubMixin):

    </Tip>

    Most of these parameters are explained in more detail in [this blog
    post](https://huggingface.co/blog/how-to-generate).

    Args:
        > Parameters that control the length of the output

@@ -73,6 +93,9 @@ class GenerationConfig(PushToHubMixin):
            [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
        penalty_alpha (`float`, *optional*):
            The value balances the model confidence and the degeneration penalty in contrastive search decoding.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should use the past last key/values attentions (if applicable to the model) to
            speed up decoding.

        > Parameters for manipulation of the model output logits

@@ -108,13 +131,13 @@ class GenerationConfig(PushToHubMixin):
            words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this
            triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
            can allow different forms of each word.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should use the past last key/values attentions (if applicable to the model) to
            speed up decoding.
        renormalize_logits (`bool`, *optional*, defaults to `False`):
            Whether to renormalize the logits after applying all the logits processors or warpers (including the custom
            ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits
            are normalized but some logit processors or warpers break the normalization.
        constraints (`List[Constraint]`, *optional*):
            Custom constraints that can be added to the generation to ensure that the output will contain the use of
            certain tokens as defined by `Constraint` objects, in the most sensible way possible.
        forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`):
            The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
            multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
@@ -191,6 +214,7 @@ class GenerationConfig(PushToHubMixin):
        self.num_beams = kwargs.pop("num_beams", 1)
        self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
        self.penalty_alpha = kwargs.pop("penalty_alpha", None)
        self.use_cache = kwargs.pop("use_cache", True)

        # Parameters for manipulation of the model output logits
        self.temperature = kwargs.pop("temperature", 1.0)
@@ -202,7 +226,9 @@ class GenerationConfig(PushToHubMixin):
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
        self.bad_words_ids = kwargs.pop("bad_words_ids", None)
        self.force_word_ids = kwargs.pop("force_word_ids", None)
        self.force_words_ids = kwargs.pop("force_words_ids", None)
        self.renormalize_logits = kwargs.pop("renormalize_logits", False)
        self.constraints = kwargs.pop("constraints", None)
        self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
        self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
        self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
@@ -230,12 +256,20 @@ class GenerationConfig(PushToHubMixin):
        # Wild card
        self.generation_kwargs = kwargs.pop("generation_kwargs", {})

        # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub interface.
        # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
        # interface.
        self._from_model_config = kwargs.pop("_from_model_config", False)
        self._commit_hash = kwargs.pop("_commit_hash", None)
        self.transformers_version = kwargs.pop("transformers_version", __version__)

    def __eq__(self, other):
        return self.__dict__ == other.__dict__
        self_dict = self.__dict__.copy()
        other_dict = other.__dict__.copy()
        # ignore metadata
        for metadata_field in ("_from_model_config", "_commit_hash", "transformers_version"):
            self_dict.pop(metadata_field, None)
            other_dict.pop(metadata_field, None)
        return self_dict == other_dict

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"
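The reworked `__eq__` above compares only the generation parameters and skips the bookkeeping fields; a small sketch of the intended behaviour (not from the commit):

```python
from transformers import GenerationConfig

a = GenerationConfig(temperature=0.7)
b = GenerationConfig(temperature=0.7, _from_model_config=True)  # differs only in metadata
c = GenerationConfig(temperature=0.9)

assert a == b  # _from_model_config, _commit_hash and transformers_version are ignored
assert a != c  # a real generation parameter still breaks equality
```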
@@ -484,18 +518,11 @@ class GenerationConfig(PushToHubMixin):
            kwargs["_commit_hash"] = config_dict["_commit_hash"]

        config = cls(**config_dict)

        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)
        unused_kwargs = config.update(**kwargs)

        logger.info(f"Generate config {config}")
        if return_unused_kwargs:
            return config, kwargs
            return config, unused_kwargs
        else:
            return config
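With the hunk above, any kwargs that `from_dict` does not recognise are surfaced through `update()`. A hedged round-trip sketch, using only a local temporary directory (no Hub access assumed):

```python
import tempfile
from transformers import GenerationConfig

with tempfile.TemporaryDirectory() as tmp_dir:
    GenerationConfig(top_k=10).save_pretrained(tmp_dir)

    # kwargs matching existing attributes are applied; everything else is returned untouched
    config, unused = GenerationConfig.from_pretrained(
        tmp_dir, top_k=1, unknown_kwarg="foo", return_unused_kwargs=True
    )
    print(config.top_k)  # 1
    print(unused)        # {'unknown_kwarg': 'foo'}
```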
@@ -568,3 +595,54 @@ class GenerationConfig(PushToHubMixin):
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string(use_diff=use_diff))

    @classmethod
    def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig":
        """
        Instantiates a [`GenerationConfig`] from a [`PretrainedConfig`]. This function is useful to convert legacy
        [`PretrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`].

        Args:
            model_config (`PretrainedConfig`):
                The model config that will be used to instantiate the generation config.

        Returns:
            [`GenerationConfig`]: The configuration object instantiated from those parameters.
        """
        config_dict = model_config.to_dict()
        config = cls.from_dict(config_dict, return_unused_kwargs=False)

        # Special case: some models have generation attributes set in the decoder. Use them if still unset in the
        # generation config.
        for decoder_name in ("decoder", "generator"):
            if decoder_name in config_dict:
                default_generation_config = GenerationConfig()
                decoder_config = config_dict[decoder_name]
                for attr in config.to_dict().keys():
                    if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
                        setattr(config, attr, decoder_config[attr])

        config._from_model_config = True
        return config
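A usage sketch for the new classmethod, mirroring the test added at the bottom of this commit (`gpt2` is only an example checkpoint):

```python
from transformers import AutoConfig, GenerationConfig

model_config = AutoConfig.from_pretrained("gpt2")  # example checkpoint
generation_config = GenerationConfig.from_model_config(model_config)

# Generation-related fields stored in the legacy model config are carried over
print(generation_config.eos_token_id == model_config.eos_token_id)  # True
print(generation_config._from_model_config)  # True
```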
    def update(self, **kwargs):
        """
        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
        returning all the unused kwargs.

        Args:
            kwargs (`Dict[str, Any]`):
                Dictionary of attributes to tentatively update this class.

        Returns:
            `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
        """
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
                to_remove.append(key)

        # remove all the attributes that were updated, without modifying the input dict
        unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
        return unused_kwargs
File diff suppressed because it is too large
@@ -39,7 +39,7 @@ from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from .dynamic_module_utils import custom_object_save
from .generation import GenerationMixin
from .generation import GenerationConfig, GenerationMixin
from .pytorch_utils import (  # noqa: F401
    Conv1D,
    apply_chunking_to_forward,
@@ -1024,6 +1024,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        self.config = config
        self.name_or_path = config.name_or_path
        self.warnings_issued = {}
        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None

    def post_init(self):
        """
@@ -1106,6 +1107,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        """
        return getattr(self, self.base_model_prefix, self)

    def can_generate(self) -> bool:
        """
        Returns whether this model can generate sequences with `.generate()`.

        Returns:
            `bool`: Whether this model can generate sequences with `.generate()`.
        """
        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
        if "GenerationMixin" in str(self.prepare_inputs_for_generation):
            return False
        return True

    def get_input_embeddings(self) -> nn.Module:
        """
        Returns the model's input embeddings.
@@ -2477,6 +2490,29 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()

        # If it is a model with generation capabilities, attempt to load the generation config
        if model.can_generate():
            try:
                model.generation_config = GenerationConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    subfolder=subfolder,
                    _from_auto=from_auto_class,
                    _from_pipeline=from_pipeline,
                    **kwargs,
                )
            except OSError:
                logger.info(
                    "Generation config file not found, using a generation config created from the model config."
                )
                pass

        # Dispatch model with hooks on all devices if necessary
        if device_map is not None:
            dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index)
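Taken together, the hunks above mean that `from_pretrained` now leaves every generation-capable model with a usable `model.generation_config`. A hedged sketch of the observable behaviour (checkpoints are arbitrary examples):

```python
from transformers import AutoModel, AutoModelForCausalLM

# A causal LM overrides prepare_inputs_for_generation, so can_generate() is True
lm = AutoModelForCausalLM.from_pretrained("gpt2")  # example checkpoint
print(lm.can_generate())     # True
print(lm.generation_config)  # from generation_config.json if present, else built from the model config

# A bare encoder keeps the GenerationMixin default, so it cannot generate
encoder = AutoModel.from_pretrained("bert-base-uncased")  # example checkpoint
print(encoder.can_generate())     # False
print(encoder.generation_config)  # None
```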
@@ -14,6 +14,7 @@
# limitations under the License.
"""RAG model implementation."""

import copy
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union

@@ -21,7 +22,7 @@ import torch
from torch import nn

from ...configuration_utils import PretrainedConfig
from ...generation import BeamSearchScorer, LogitsProcessorList, StoppingCriteriaList
from ...generation import BeamSearchScorer, GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings
@@ -1384,33 +1385,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
        context_input_ids: Optional[torch.LongTensor] = None,
        context_attention_mask: Optional[torch.LongTensor] = None,
        doc_scores: Optional[torch.FloatTensor] = None,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        early_stopping: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        num_beams: Optional[int] = None,
        num_beam_groups: Optional[int] = None,
        diversity_penalty: Optional[float] = None,
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        encoder_no_repeat_ngram_size: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        bad_words_ids: Optional[List[List[int]]] = None,
        num_return_sequences: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        n_docs: Optional[int] = None,
        generation_config: Optional[GenerationConfig] = None,
        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
        logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
        renormalize_logits: Optional[bool] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
        remove_invalid_values: Optional[bool] = None,
        exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
        **model_kwargs
        **kwargs
    ) -> torch.LongTensor:
        """
        Implements RAG token decoding.
@@ -1444,51 +1424,15 @@ class RagTokenForGeneration(RagPreTrainedModel):

                If the model is not initialized with a `retriever`, `context_input_ids` has to be provided to the
                forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
            max_length (`int`, *optional*, defaults to 20):
                The maximum length of the sequence to be generated.
            min_length (`int`, *optional*, defaults to 10):
                The minimum length of the sequence to be generated.
            early_stopping (`bool`, *optional*, defaults to `False`):
                Whether or not to stop the beam search when at least `num_beams` sentences are finished per batch or
                not.
            use_cache (`bool`, *optional*, defaults to `True`):
                Whether or not the model should use the past last key/values attentions (if applicable to the model) to
                speed up decoding.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            bos_token_id (`int`, *optional*):
                The id of the *beginning-of-sequence* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            length_penalty (`float`, *optional*, defaults to 1.0):
                Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent
                to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
                the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
                while `length_penalty` < 0.0 encourages shorter sequences.
            no_repeat_ngram_size (`int`, *optional*, defaults to 0):
                If set to int > 0, all ngrams of that size can only occur once.
            encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
                If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
                `decoder_input_ids`.
            bad_words_ids (`List[int]`, *optional*):
                List of token ids that are not allowed to be generated. In order to get the tokens of the words that
                should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
            num_beams (`int`, *optional*, defaults to 1):
                Number of beams for beam search. 1 means no beam search.
            num_beam_groups (`int`, *optional*, defaults to 1):
                Number of groups to divide `num_beams` into in order to ensure diversity among different groups of
                beams. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
            diversity_penalty (`float`, *optional*, defaults to 0.0):
                This value is subtracted from a beam's score if it generates a token same as any beam from other group
                at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is
                enabled.
            num_return_sequences (`int`, *optional*, defaults to 1):
                The number of independently computed returned sequences for each element in the batch. Note that this
                is not the value we pass to the `generator`'s [`~generation.GenerationMixin.generate`] function, where
                we set `num_return_sequences` to `num_beams`.
            decoder_start_token_id (`int`, *optional*):
                If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
            n_docs (`int`, *optional*, defaults to `config.n_docs`)
                Number of documents to retrieve and/or number of documents for which to generate an answer.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which has the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
                If provided, this function constrains the beam search to allowed tokens only at each step. If not
                provided no constraint is applied. This function takes 2 arguments `inputs_ids` and the batch ID
@@ -1497,53 +1441,30 @@ class RagTokenForGeneration(RagPreTrainedModel):
                constrained generation conditioned on the prefix, as described in [Autoregressive Entity
                Retrieval](https://arxiv.org/abs/2010.00904).
            logits_processor (`LogitsProcessorList`, *optional*):
                Custom logits processors that complement the default logits processors built from arguments and a
                model's config. If a logit processor is passed that is already created with the arguments or a model's
                config an error is thrown.
                Custom logits processors that complement the default logits processors built from arguments and a
                model's config. If a logit processor is passed that is already created with the arguments or a model's
                config an error is thrown.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                Custom stopping criteria that complement the default stopping criteria built from arguments and a
                model's config. If a stopping criteria is passed that is already created with the arguments or a
                model's config an error is thrown.
            forced_bos_token_id (`int`, *optional*):
                The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful
                for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be
                the target language token.
            forced_eos_token_id (`int`, *optional*):
                The id of the token to force as the last generated token when `max_length` is reached.
            remove_invalid_values (`bool`, *optional*):
                Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to
                crash. Note that using `remove_invalid_values` can slow down generation.
                Custom stopping criteria that complement the default stopping criteria built from arguments and a
                model's config. If a stopping criteria is passed that is already created with the arguments or a
                model's config an error is thrown.
            kwargs:
                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                forwarded to the `forward` function of the model.

        Return:
            `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated
            sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches
            finished early due to the `eos_token_id`.
        """
        # Handle `generation_config` and kwargs that might update it
        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs

        # set default parameters
        n_docs = n_docs if n_docs is not None else self.config.n_docs
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
        max_length = max_length if max_length is not None else self.config.max_length
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        decoder_start_token_id = (
            decoder_start_token_id
            if decoder_start_token_id is not None
            else self.config.generator.decoder_start_token_id
        )
        remove_invalid_values = (
            remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
        )
        exponential_decay_length_penalty = (
            exponential_decay_length_penalty
            if exponential_decay_length_penalty is not None
            else self.config.exponential_decay_length_penalty
        )

        # retrieve docs
        if self.retriever is not None and context_input_ids is None:
@@ -1583,8 +1504,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
        encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)

        input_ids = torch.full(
            (batch_size * num_beams, 1),
            decoder_start_token_id,
            (batch_size * generation_config.num_beams, 1),
            generation_config.decoder_start_token_id,
            dtype=torch.long,
            device=next(self.parameters()).device,
        )
@@ -1600,10 +1521,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
            return tensor.reshape((batch_size * num_beams * n_docs,) + tensor.shape[3:])

        # correctly extend last_hidden_state and attention mask
        context_attention_mask = extend_enc_output(context_attention_mask, num_beams=num_beams)
        encoder_outputs["last_hidden_state"] = extend_enc_output(last_hidden_state, num_beams=num_beams)
        context_attention_mask = extend_enc_output(context_attention_mask, num_beams=generation_config.num_beams)
        encoder_outputs["last_hidden_state"] = extend_enc_output(
            last_hidden_state, num_beams=generation_config.num_beams
        )

        doc_scores = doc_scores.repeat_interleave(num_beams, dim=0)
        doc_scores = doc_scores.repeat_interleave(generation_config.num_beams, dim=0)

        # define start_len & additional parameters
        model_kwargs["doc_scores"] = doc_scores
@@ -1612,64 +1535,51 @@ class RagTokenForGeneration(RagPreTrainedModel):
        model_kwargs["n_docs"] = n_docs

        pre_processor = self._get_logits_processor(
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=context_input_ids,
            bad_words_ids=bad_words_ids,
            min_length=min_length,
            max_length=max_length,
            eos_token_id=eos_token_id,
            forced_bos_token_id=forced_bos_token_id,
            forced_eos_token_id=forced_eos_token_id,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            num_beams=num_beams,
            num_beam_groups=num_beam_groups,
            diversity_penalty=diversity_penalty,
            remove_invalid_values=remove_invalid_values,
            exponential_decay_length_penalty=exponential_decay_length_penalty,
            logits_processor=logits_processor,
            renormalize_logits=renormalize_logits,
        )

        if num_beams == 1:
            if num_return_sequences > 1:
        if generation_config.num_beams == 1:
            if generation_config.num_return_sequences > 1:
                raise ValueError(
                    f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
                    f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
                    " greedy search."
                )
            return self.greedy_search(
                input_ids,
                logits_processor=pre_processor,
                max_length=max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                max_length=generation_config.max_length,
                pad_token_id=generation_config.pad_token_id,
                eos_token_id=generation_config.eos_token_id,
                **model_kwargs,
            )
        elif num_beams > 1:
            length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
            early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
            if num_return_sequences > num_beams:
        elif generation_config.num_beams > 1:
            if generation_config.num_return_sequences > generation_config.num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                num_beams=generation_config.num_beams,
                device=self.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
                num_beam_hyps_to_keep=generation_config.num_return_sequences,
            )
            return self.beam_search(
                input_ids,
                beam_scorer,
                logits_processor=pre_processor,
                max_length=max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                max_length=generation_config.max_length,
                pad_token_id=generation_config.pad_token_id,
                eos_token_id=generation_config.eos_token_id,
                **model_kwargs,
            )
        else:
            raise ValueError(f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {num_beams}")
            raise ValueError(
                f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {generation_config.num_beams}"
            )

    def get_input_embeddings(self):
        return self.rag.generator.get_input_embeddings()
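The behavioural core of the RAG change above is that explicit keyword arguments are folded into a copy of the generation config, and anything `update()` does not recognise becomes a model kwarg. A standalone sketch of that routing (the helper name and the extra flag are hypothetical, used only for illustration):

```python
import copy
from transformers import GenerationConfig

default_config = GenerationConfig(num_beams=4, max_length=64)  # stand-in for self.generation_config

def split_generate_kwargs(generation_config=None, **kwargs):
    # Mirrors the handling at the top of RagTokenForGeneration.generate
    if generation_config is None:
        generation_config = default_config
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)  # unused kwargs are treated as model kwargs
    return generation_config, model_kwargs

config, model_kwargs = split_generate_kwargs(num_beams=2, my_model_specific_flag=True)
print(config.num_beams)  # 2 -- overrides the default of 4
print(model_kwargs)      # {'my_model_specific_flag': True} -- would be forwarded to the model's forward pass
```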
@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import tempfile
import unittest

from parameterized import parameterized
from transformers.generation import GenerationConfig
from transformers import AutoConfig, GenerationConfig


class LogitsProcessorTest(unittest.TestCase):
@@ -43,3 +44,33 @@ class LogitsProcessorTest(unittest.TestCase):
        self.assertEqual(loaded_config.top_k, 50)
        self.assertEqual(loaded_config.max_length, 20)
        self.assertEqual(loaded_config.max_time, None)

    def test_from_model_config(self):
        model_config = AutoConfig.from_pretrained("gpt2")
        generation_config_from_model = GenerationConfig.from_model_config(model_config)
        default_generation_config = GenerationConfig()

        # The generation config has loaded a few non-default parameters from the model config
        self.assertNotEqual(generation_config_from_model, default_generation_config)

        # One of those parameters is eos_token_id -- check if it matches
        self.assertNotEqual(generation_config_from_model.eos_token_id, default_generation_config.eos_token_id)
        self.assertEqual(generation_config_from_model.eos_token_id, model_config.eos_token_id)

    def test_update(self):
        generation_config = GenerationConfig()
        update_kwargs = {
            "max_new_tokens": 1024,
            "foo": "bar",
        }
        update_kwargs_copy = copy.deepcopy(update_kwargs)
        unused_kwargs = generation_config.update(**update_kwargs)

        # update_kwargs was not modified (no side effects)
        self.assertEqual(update_kwargs, update_kwargs_copy)

        # update_kwargs was used to update the config on valid attributes
        self.assertEqual(generation_config.max_new_tokens, 1024)

        # `.update()` returns a dictionary of unused kwargs
        self.assertEqual(unused_kwargs, {"foo": "bar"})