Generate: add generation config class (#20218)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Joao Gante 2022-11-21 13:30:15 +00:00 committed by GitHub
parent 8503cc7550
commit 3de07473da
7 changed files with 633 additions and 3 deletions

View File

@@ -18,6 +18,14 @@ Each framework has a generate method for auto-regressive text generation impleme
- TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`].
- Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`].
<!--- TODO: add a brief description of GenerationConfig (with examples) when it becomes usable with generate --->
## GenerationConfig
[[autodoc]] generation.GenerationConfig
- from_pretrained
- save_pretrained
## GenerationMixin
[[autodoc]] generation.GenerationMixin
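As a quick orientation for this new docs page, here is a minimal usage sketch of the two documented methods (illustrative only; `"gpt2"` is just an example checkpoint, and loading assumes the repo ships a `generation_config.json`):

```python
from transformers import GenerationConfig

# Load generation defaults stored alongside a checkpoint (raises an error if
# the repo has no generation_config.json).
generation_config = GenerationConfig.from_pretrained("gpt2")

# Tweak a value and persist it locally. This only affects generation-time
# behavior; the model's config.json and weights are untouched.
generation_config.max_new_tokens = 64
generation_config.save_pretrained("./my_generation_config")
```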

View File

@@ -96,7 +96,7 @@ _import_structure = {
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
"file_utils": [],
"generation": [],
"generation": ["GenerationConfig"],
"hf_argparser": ["HfArgumentParser"],
"integrations": [
"is_clearml_available",
@@ -3258,6 +3258,9 @@ if TYPE_CHECKING:
# Feature Extractor
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
# Generation
from .generation import GenerationConfig
from .hf_argparser import HfArgumentParser
# Integrations

View File

@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available
_import_structure = {}
_import_structure = {"configuration_utils": ["GenerationConfig"]}
try:
@@ -149,6 +149,8 @@ else:
]
if TYPE_CHECKING:
from .configuration_utils import GenerationConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
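For context, the `_import_structure` dict edited above feeds transformers' `_LazyModule` (from `transformers.utils`), which defers importing each submodule until one of its exported names is first accessed. A simplified sketch of the idea, not the real implementation:

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Simplified stand-in for transformers' `_LazyModule` (illustrative only)."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported name to the submodule defining it, e.g.
        # {"GenerationConfig": "configuration_utils"}.
        self._name_to_module = {
            attr: submodule
            for submodule, attrs in import_structure.items()
            for attr in attrs
        }

    def __getattr__(self, attr):
        if attr not in self._name_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        # Import the defining submodule only on first access, then cache the result.
        module = importlib.import_module("." + self._name_to_module[attr], self.__name__)
        value = getattr(module, attr)
        setattr(self, attr, value)
        return value
```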

View File

@@ -0,0 +1,570 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Generation configuration class and utilities."""
import copy
import json
import os
from typing import Any, Dict, Optional, Union
from .. import __version__
from ..utils import (
GENERATION_CONFIG_NAME,
PushToHubMixin,
cached_file,
download_url,
extract_commit_hash,
is_remote_url,
logging,
)
logger = logging.get_logger(__name__)
class GenerationConfig(PushToHubMixin):
r"""
Class that holds a configuration for a generation task.
<Tip>
A generation configuration file can be loaded and saved to disk. Loading and using a generation configuration file
does **not** change a model configuration or weights. It only affects the model's behavior at generation time.
</Tip>
Args:
> Parameters that control the length of the output
max_length (`int`, *optional*, defaults to 20):
The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
`max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in the
prompt.
max_new_tokens (`int`, *optional*):
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
min_length (`int`, *optional*, defaults to 0):
The minimum length of the sequence to be generated.
early_stopping (`bool`, *optional*, defaults to `False`):
Whether or not to stop beam search when at least `num_beams` sentences are finished per batch.
max_time (`float`, *optional*):
The maximum amount of time you allow the computation to run for, in seconds. Generation will still finish
the current pass after the allocated time has passed.
> Parameters that control the generation strategy used
do_sample (`bool`, *optional*, defaults to `False`):
Whether or not to use sampling; use greedy decoding otherwise.
num_beams (`int`, *optional*, defaults to 1):
Number of beams for beam search. 1 means no beam search.
num_beam_groups (`int`, *optional*, defaults to 1):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. See
[this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
penalty_alpha (`float`, *optional*):
The value balances the model confidence and the degeneration penalty in contrastive search decoding.
> Parameters for manipulation of the model output logits
temperature (`float`, *optional*, defaults to 1.0):
The value used to modulate the next token probabilities.
top_k (`int`, *optional*, defaults to 50):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`, *optional*, defaults to 1.0):
If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to
`top_p` or higher are kept for generation.
typical_p (`float`, *optional*, defaults to 1.0):
The amount of probability mass from the original distribution to be considered in typical decoding. If set
to 1.0 it has no effect. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
diversity_penalty (`float`, *optional*, defaults to 0.0):
This value is subtracted from a beam's score if it generates a token that is the same as a token generated
by any beam from another group at a particular time step. Note that `diversity_penalty` is only effective if
group beam search is enabled.
repetition_penalty (`float`, *optional*, defaults to 1.0):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
length_penalty (`float`, *optional*, defaults to 1.0):
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
`length_penalty` < 0.0 encourages shorter sequences.
no_repeat_ngram_size (`int`, *optional*, defaults to 0):
If set to int > 0, all ngrams of that size can only occur once.
bad_words_ids (`List[List[int]]`, *optional*):
List of token ids that are not allowed to be generated. In order to get the token ids of the words that
should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
add_special_tokens=False).input_ids`.
force_words_ids (`List[List[int]]` or `List[List[List[int]]]`, *optional*):
List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple list of
words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this
triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
can allow different forms of each word.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past key/values attentions (if applicable to the model) to speed up
decoding.
renormalize_logits (`bool`, *optional*, defaults to `False`):
Whether to renormalize the logits after applying all the logits processors or warpers (including the custom
ones). It's highly recommended to set this flag to `True`, as the search algorithms assume the score logits
are normalized, while some logit processors or warpers break the normalization.
forced_bos_token_id (`int`, *optional*):
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
language token.
forced_eos_token_id (`int`, *optional*):
The id of the token to force as the last generated token when `max_length` is reached.
remove_invalid_values (`bool`, *optional*, defaults to `False`):
Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method from
crashing. Note that using `remove_invalid_values` can slow down generation.
exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
This tuple adds an exponentially increasing length penalty after a certain number of tokens have been
generated. The tuple shall consist of `(start_index, decay_factor)`, where `start_index` indicates where the
penalty starts and `decay_factor` represents the factor of exponential decay.
suppress_tokens (`List[int]`, *optional*):
A list of tokens that will be suppressed at generation. The `SuppressTokens` logit processor will set their
log probs to `-inf` so that they are not sampled.
begin_suppress_tokens (`List[int]`, *optional*):
A list of tokens that will be suppressed at the beginning of the generation. The `SuppressBeginTokens` logit
processor will set their log probs to `-inf` so that they are not sampled.
forced_decoder_ids (`List[List[int]]`, *optional*):
A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token
of index 123.
> Parameters that define the output variables of `generate`
num_return_sequences (`int`, *optional*, defaults to 1):
The number of independently computed returned sequences for each element in the batch.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
> Special tokens that can be used at generation time
pad_token_id (`int`, *optional*):
The id of the *padding* token.
bos_token_id (`int`, *optional*):
The id of the *beginning-of-sequence* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
> Generation parameters exclusive to encoder-decoder models
encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
`decoder_input_ids`.
decoder_start_token_id (`int`, *optional*):
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
> Wild card
generation_kwargs:
Additional generation kwargs will be forwarded to the `generate` function of the model. Kwargs that are not
present in `generate`'s signature will be used in the model forward pass.
"""
def __init__(self, **kwargs):
# Parameters that control the length of the output
self.max_length = kwargs.pop("max_length", 20)
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
self.min_length = kwargs.pop("min_length", 0)
self.early_stopping = kwargs.pop("early_stopping", False)
self.max_time = kwargs.pop("max_time", None)
# Parameters that control the generation strategy used
self.do_sample = kwargs.pop("do_sample", False)
self.num_beams = kwargs.pop("num_beams", 1)
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
self.penalty_alpha = kwargs.pop("penalty_alpha", None)
# Parameters for manipulation of the model output logits
self.temperature = kwargs.pop("temperature", 1.0)
self.top_k = kwargs.pop("top_k", 50)
self.top_p = kwargs.pop("top_p", 1.0)
self.typical_p = kwargs.pop("typical_p", 1.0)
self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
self.length_penalty = kwargs.pop("length_penalty", 1.0)
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
self.force_words_ids = kwargs.pop("force_words_ids", None)
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)
self.suppress_tokens = kwargs.pop("suppress_tokens", None)
self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
# Parameters that define the output variables of `generate`
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
self.output_attentions = kwargs.pop("output_attentions", False)
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
self.output_scores = kwargs.pop("output_scores", False)
self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
# Special tokens that can be used at generation time
self.pad_token_id = kwargs.pop("pad_token_id", None)
self.bos_token_id = kwargs.pop("bos_token_id", None)
self.eos_token_id = kwargs.pop("eos_token_id", None)
# Generation parameters exclusive to encoder-decoder models
self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the Hub interface.
self._commit_hash = kwargs.pop("_commit_hash", None)
self.transformers_version = kwargs.pop("transformers_version", __version__)
def __eq__(self, other):
return self.__dict__ == other.__dict__
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
config_file_name: Optional[Union[str, os.PathLike]] = None,
push_to_hub: bool = False,
**kwargs
):
r"""
Save a generation configuration object to the directory `save_directory`, so that it can be re-loaded using the
[`~GenerationConfig.from_pretrained`] class method.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist).
config_file_name (`str` or `os.PathLike`, *optional*, defaults to `"generation_config.json"`):
Name of the generation configuration JSON file to be saved in `save_directory`.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs:
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
config_file_name = config_file_name if config_file_name is not None else GENERATION_CONFIG_NAME
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)
output_config_file = os.path.join(save_directory, config_file_name)
self.to_json_file(output_config_file, use_diff=True)
logger.info(f"Configuration saved in {output_config_file}")
if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
)
@classmethod
def from_pretrained(
cls,
pretrained_model_name: Union[str, os.PathLike],
config_file_name: Optional[Union[str, os.PathLike]] = None,
**kwargs
) -> "GenerationConfig":
r"""
Instantiate a [`GenerationConfig`] from a generation configuration file.
Args:
pretrained_model_name (`str` or `os.PathLike`):
This can be either:
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
- a path to a *directory* containing a configuration file saved using the
[`~GenerationConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
config_file_name (`str` or `os.PathLike`, *optional*, defaults to `"generation_config.json"`):
Name of the generation configuration JSON file to be loaded from `pretrained_model_name`.
cache_dir (`str` or `os.PathLike`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force to (re-)download the configuration files and override the cached versions if
they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Attempts to resume the download if such a file
exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
<Tip>
To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
</Tip>
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
If `False`, then this function returns just the final configuration object.
If `True`, then this function returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
part of `kwargs` which has not been used to update `config` and is otherwise ignored.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
kwargs (`Dict[str, Any]`, *optional*):
The values in kwargs of any keys which are configuration attributes will be used to override the loaded
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
by the `return_unused_kwargs` keyword parameter.
Returns:
[`GenerationConfig`]: The configuration object instantiated from this pretrained model.
Examples:
```python
>>> from transformers import GenerationConfig
>>> # Download configuration from huggingface.co and cache.
>>> generation_config = GenerationConfig.from_pretrained("gpt2")
>>> # E.g. config was saved using *save_pretrained('./test/saved_model/')*
>>> generation_config.save_pretrained("./test/saved_model/")
>>> generation_config = GenerationConfig.from_pretrained("./test/saved_model/")
>>> # You can also specify configuration names to your generation configuration file
>>> generation_config.save_pretrained("./test/saved_model/", config_file_name="my_configuration.json")
>>> generation_config = GenerationConfig.from_pretrained("./test/saved_model/", "my_configuration.json")
>>> # If you'd like to try a minor variation to an existing configuration, you can also pass generation
>>> # arguments to `.from_pretrained()`. Be mindful that typos and unused arguments will be ignored
>>> generation_config, unused_kwargs = GenerationConfig.from_pretrained(
... "gpt2", top_k=1, foo=False, return_unused_kwargs=True
... )
>>> generation_config.top_k
1
>>> unused_kwargs
{'foo': False}
```"""
config_file_name = config_file_name if config_file_name is not None else GENERATION_CONFIG_NAME
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", "")
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
commit_hash = kwargs.pop("_commit_hash", None)
user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
if from_pipeline is not None:
user_agent["using_pipeline"] = from_pipeline
config_path = os.path.join(pretrained_model_name, config_file_name)
config_path = str(config_path)
is_local = os.path.exists(config_path)
if os.path.isfile(os.path.join(subfolder, config_path)):
# Special case when config_path is a local file
resolved_config_file = config_path
is_local = True
elif is_remote_url(config_path):
configuration_file = config_path
resolved_config_file = download_url(config_path)
else:
configuration_file = config_file_name
try:
# Load from local folder or from cache or download from model Hub and cache
resolved_config_file = cached_file(
pretrained_model_name,
configuration_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=commit_hash,
)
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
except EnvironmentError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
# the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError(
f"Can't load the configuration of '{pretrained_model_name}'. If you were trying to load it"
" from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
f" name. Otherwise, make sure '{pretrained_model_name}' is the correct path to a directory"
f" containing a {configuration_file} file"
)
try:
# Load config dict
config_dict = cls._dict_from_json_file(resolved_config_file)
config_dict["_commit_hash"] = commit_hash
except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
)
if is_local:
logger.info(f"loading configuration file {resolved_config_file}")
else:
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
return cls.from_dict(config_dict, **kwargs)
@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
return json.loads(text)
@classmethod
def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig":
"""
Instantiates a [`GenerationConfig`] from a Python dictionary of parameters.
Args:
config_dict (`Dict[str, Any]`):
Dictionary that will be used to instantiate the configuration object.
kwargs (`Dict[str, Any]`):
Additional parameters from which to initialize the configuration object.
Returns:
[`GenerationConfig`]: The configuration object instantiated from those parameters.
"""
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
# Those arguments may be passed along for our internal telemetry.
# We remove them so they don't appear in `return_unused_kwargs`.
kwargs.pop("_from_auto", None)
kwargs.pop("_from_pipeline", None)
# The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
kwargs["_commit_hash"] = config_dict["_commit_hash"]
config = cls(**config_dict)
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
logger.info(f"Generate config {config}")
if return_unused_kwargs:
return config, kwargs
else:
return config
def to_diff_dict(self) -> Dict[str, Any]:
"""
Removes all attributes from the config which correspond to the default config attributes, for better
readability, and serializes to a Python dictionary.
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
config_dict = self.to_dict()
# get the default config dict
default_config_dict = GenerationConfig().to_dict()
serializable_config_dict = {}
# only serialize values that differ from the default config
for key, value in config_dict.items():
if key not in default_config_dict or key == "transformers_version" or value != default_config_dict[key]:
serializable_config_dict[key] = value
return serializable_config_dict
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary.
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
if "_commit_hash" in output:
del output["_commit_hash"]
# Transformers version when serializing this file
output["transformers_version"] = __version__
return output
def to_json_string(self, use_diff: bool = True) -> str:
"""
Serializes this instance to a JSON string.
Args:
use_diff (`bool`, *optional*, defaults to `True`):
If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
is serialized to JSON string.
Returns:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
"""
if use_diff is True:
config_dict = self.to_diff_dict()
else:
config_dict = self.to_dict()
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
"""
Save this instance to a JSON file.
Args:
json_file_path (`str` or `os.PathLike`):
Path to the JSON file in which this configuration instance's parameters will be saved.
use_diff (`bool`, *optional*, defaults to `True`):
If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
is serialized to JSON file.
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string(use_diff=use_diff))
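To make the serialization behavior above concrete, a small sketch of how `to_diff_dict` trims the output to non-default values (the version string is shown as a placeholder):

```python
from transformers import GenerationConfig

config = GenerationConfig(do_sample=True, top_k=10)

# to_dict() contains every attribute; to_diff_dict() keeps only the values
# that differ from a fresh GenerationConfig(), plus transformers_version.
print(config.to_diff_dict())
# {'do_sample': True, 'top_k': 10, 'transformers_version': '...'}

# to_json_string() defaults to use_diff=True, so it serializes the same keys.
print(config.to_json_string())
```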

View File

@@ -178,6 +178,7 @@ SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
CONFIG_NAME = "config.json"
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
GENERATION_CONFIG_NAME = "generation_config.json"
MODEL_CARD_NAME = "modelcard.json"
SENTENCEPIECE_UNDERLINE = "▁"

View File

@@ -0,0 +1,45 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from parameterized import parameterized
from transformers.generation import GenerationConfig
class GenerationConfigTest(unittest.TestCase):
@parameterized.expand([(None,), ("foo.json",)])
def test_save_load_config(self, config_name):
config = GenerationConfig(
do_sample=True,
temperature=0.7,
length_penalty=1.0,
bad_words_ids=[[1, 2, 3], [4, 5]],
)
with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(tmp_dir, config_file_name=config_name)
loaded_config = GenerationConfig.from_pretrained(tmp_dir, config_file_name=config_name)
# Checks parameters that were specified
self.assertEqual(loaded_config.do_sample, True)
self.assertEqual(loaded_config.temperature, 0.7)
self.assertEqual(loaded_config.length_penalty, 1.0)
self.assertEqual(loaded_config.bad_words_ids, [[1, 2, 3], [4, 5]])
# Checks parameters that were not specified (defaults)
self.assertEqual(loaded_config.top_k, 50)
self.assertEqual(loaded_config.max_length, 20)
self.assertEqual(loaded_config.max_time, None)
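Outside the test harness, the same round trip looks like this (a standalone sketch mirroring the assertions above):

```python
import tempfile

from transformers.generation import GenerationConfig

config = GenerationConfig(do_sample=True, temperature=0.7)
with tempfile.TemporaryDirectory() as tmp_dir:
    config.save_pretrained(tmp_dir)  # writes generation_config.json
    loaded = GenerationConfig.from_pretrained(tmp_dir)

assert loaded.temperature == 0.7  # explicitly set value survives the round trip
assert loaded.top_k == 50  # unspecified values fall back to the defaults
```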

View File

@@ -12,8 +12,9 @@ docs/source/en/model_doc/byt5.mdx
docs/source/en/model_doc/tapex.mdx
docs/source/en/model_doc/donut.mdx
docs/source/en/model_doc/encoder-decoder.mdx
src/transformers/generation/utils.py
src/transformers/generation/configuration_utils.py
src/transformers/generation/tf_utils.py
src/transformers/generation/utils.py
src/transformers/models/albert/configuration_albert.py
src/transformers/models/albert/modeling_albert.py
src/transformers/models/albert/modeling_tf_albert.py