[generation] Less verbose warnings by default (#38179)

* tmp commit (imports broken)

* working version; update tests

* remove line break

* shorter msg

* dola checks need num_beams=1; other minor PR comments

* update early trainer failing on bad gen config

* make fixup

* test msg
This commit is contained in:
Joao Gante 2025-05-19 11:03:37 +01:00 committed by GitHub
parent 656e2eab3f
commit dbb9813dff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 251 additions and 247 deletions

View File

@ -35,6 +35,7 @@ from ..utils import (
is_torch_available, is_torch_available,
logging, logging,
) )
from ..utils.deprecation import deprecate_kwarg
if TYPE_CHECKING: if TYPE_CHECKING:
@ -514,7 +515,7 @@ class GenerationConfig(PushToHubMixin):
raise err raise err
# Validate the values of the attributes # Validate the values of the attributes
self.validate(is_init=True) self.validate()
def __hash__(self): def __hash__(self):
return hash(self.to_json_string(ignore_metadata=True)) return hash(self.to_json_string(ignore_metadata=True))
@ -592,7 +593,8 @@ class GenerationConfig(PushToHubMixin):
) )
return generation_mode return generation_mode
def validate(self, is_init=False): @deprecate_kwarg("is_init", version="4.54.0")
def validate(self, strict=False):
""" """
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
of parameterization that can be detected as incorrect from the configuration instance alone. of parameterization that can be detected as incorrect from the configuration instance alone.
@ -600,174 +602,24 @@ class GenerationConfig(PushToHubMixin):
Note that some parameters not validated here are best validated at generate runtime, as they may depend on Note that some parameters not validated here are best validated at generate runtime, as they may depend on
other inputs and/or the model, such as parameters related to the generation length. other inputs and/or the model, such as parameters related to the generation length.
Arg: Args:
is_init (`bool`, *optional*, defaults to `False`): strict (bool): If True, raise an exception for any issues found. If False, only log issues.
Whether the validation is performed during the initialization of the instance.
""" """
minor_issues = {} # format: {attribute_name: issue_description}
# Validation of individual attributes # 1. Validation of individual attributes
# 1.1. Decoding attributes
if self.early_stopping not in {True, False, "never"}: if self.early_stopping not in {True, False, "never"}:
raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.") raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
if self.max_new_tokens is not None and self.max_new_tokens <= 0: if self.max_new_tokens is not None and self.max_new_tokens <= 0:
raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.") raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
if self.pad_token_id is not None and self.pad_token_id < 0: if self.pad_token_id is not None and self.pad_token_id < 0:
warnings.warn( minor_issues["pad_token_id"] = (
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch " f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch "
"generating, if there is padding. Please set `pad_token_id` explicitly as " "generating, if there is padding. Please set `pad_token_id` explicitly as "
"`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation" "`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation"
) )
# 1.2. Cache attributes
# Validation of attribute relations:
fix_location = ""
if is_init:
fix_location = (
" This was detected when initializing the generation config instance, which means the corresponding "
"file may hold incorrect parameterization and should be fixed."
)
# 1. detect sampling-only parameterization when not in sampling mode
if self.do_sample is False:
greedy_wrong_parameter_msg = (
"`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
+ fix_location
)
if self.temperature is not None and self.temperature != 1.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
UserWarning,
)
if self.top_p is not None and self.top_p != 1.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
UserWarning,
)
if self.min_p is not None:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p),
UserWarning,
)
if self.typical_p is not None and self.typical_p != 1.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
UserWarning,
)
if (
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
): # contrastive search uses top_k
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
UserWarning,
)
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
UserWarning,
)
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
UserWarning,
)
# 2. detect beam-only parameterization when not in beam mode
if self.num_beams is None:
warnings.warn("`num_beams` is set to None - defaulting to 1.", UserWarning)
self.num_beams = 1
if self.num_beams == 1:
single_beam_wrong_parameter_msg = (
"`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
"in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`." + fix_location
)
if self.early_stopping is not False:
warnings.warn(
single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
UserWarning,
)
if self.num_beam_groups is not None and self.num_beam_groups != 1:
warnings.warn(
single_beam_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups
),
UserWarning,
)
if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
warnings.warn(
single_beam_wrong_parameter_msg.format(
flag_name="diversity_penalty", flag_value=self.diversity_penalty
),
UserWarning,
)
if self.length_penalty is not None and self.length_penalty != 1.0:
warnings.warn(
single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
UserWarning,
)
if self.constraints is not None:
warnings.warn(
single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints),
UserWarning,
)
# 3. detect incorrect parameterization specific to advanced beam modes
else:
# constrained beam search
if self.constraints is not None or self.force_words_ids is not None:
constrained_wrong_parameter_msg = (
"one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. However, "
"`{flag_name}` is set to `{flag_value}`, which is incompatible with this generation mode. Set "
"`constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue." + fix_location
)
if self.do_sample is True:
raise ValueError(
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
)
if self.num_beam_groups is not None and self.num_beam_groups != 1:
raise ValueError(
constrained_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups
)
)
# group beam search
if self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
group_error_prefix = (
"`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
"this generation mode, "
)
if self.do_sample is True:
raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
if self.num_beams % self.num_beam_groups != 0:
raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
if self.diversity_penalty == 0.0:
raise ValueError(
group_error_prefix
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
)
# DoLa generation
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
warnings.warn(
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
"DoLa decoding is `repetition_penalty>=1.2`.",
UserWarning,
)
# 4. check `num_return_sequences`
if self.num_return_sequences != 1:
if self.num_beams == 1:
if self.do_sample is False:
raise ValueError(
"Greedy methods without beam search do not support `num_return_sequences` different than 1 "
f"(got {self.num_return_sequences})."
)
elif self.num_return_sequences > self.num_beams:
raise ValueError(
f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
f"({self.num_beams})."
)
# 5. check cache-related arguments
if self.cache_implementation is not None and self.cache_implementation not in ALL_CACHE_IMPLEMENTATIONS: if self.cache_implementation is not None and self.cache_implementation not in ALL_CACHE_IMPLEMENTATIONS:
raise ValueError( raise ValueError(
f"Invalid `cache_implementation` ({self.cache_implementation}). Choose one of: " f"Invalid `cache_implementation` ({self.cache_implementation}). Choose one of: "
@ -784,6 +636,141 @@ class GenerationConfig(PushToHubMixin):
if not isinstance(self.cache_config, cache_class): if not isinstance(self.cache_config, cache_class):
self.cache_config = cache_class.from_dict(self.cache_config) self.cache_config = cache_class.from_dict(self.cache_config)
self.cache_config.validate() self.cache_config.validate()
# 1.3. Performance attributes
if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
raise ValueError(
f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
"instance of `CompileConfig`."
)
# 1.4. Watermarking attributes
if self.watermarking_config is not None:
if not (
isinstance(self.watermarking_config, WatermarkingConfig)
or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
):
minor_issues["watermarking_config"] = (
"`watermarking_config` as a dict is deprecated and will be removed in v4.54.0. Please construct "
"`watermarking_config` object with `WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class."
)
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
self.watermarking_config.validate()
# 2. Validation of attribute combinations
# 2.1. detect sampling-only parameterization when not in sampling mode
if self.do_sample is False:
greedy_wrong_parameter_msg = (
"`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
)
if self.temperature is not None and self.temperature != 1.0:
minor_issues["temperature"] = greedy_wrong_parameter_msg.format(
flag_name="temperature", flag_value=self.temperature
)
if self.top_p is not None and self.top_p != 1.0:
minor_issues["top_p"] = greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p)
if self.min_p is not None:
minor_issues["min_p"] = greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p)
if self.typical_p is not None and self.typical_p != 1.0:
minor_issues["typical_p"] = greedy_wrong_parameter_msg.format(
flag_name="typical_p", flag_value=self.typical_p
)
if (
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
): # contrastive search uses top_k
minor_issues["top_k"] = greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k)
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
minor_issues["epsilon_cutoff"] = greedy_wrong_parameter_msg.format(
flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff
)
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
minor_issues["eta_cutoff"] = greedy_wrong_parameter_msg.format(
flag_name="eta_cutoff", flag_value=self.eta_cutoff
)
# 2.2. detect beam-only parameterization when not in beam mode
if self.num_beams == 1:
single_beam_wrong_parameter_msg = (
"`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
"in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`."
)
if self.early_stopping is not False:
minor_issues["early_stopping"] = single_beam_wrong_parameter_msg.format(
flag_name="early_stopping", flag_value=self.early_stopping
)
if self.num_beam_groups is not None and self.num_beam_groups != 1:
minor_issues["num_beam_groups"] = single_beam_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups
)
if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
minor_issues["diversity_penalty"] = single_beam_wrong_parameter_msg.format(
flag_name="diversity_penalty", flag_value=self.diversity_penalty
)
if self.length_penalty is not None and self.length_penalty != 1.0:
minor_issues["length_penalty"] = single_beam_wrong_parameter_msg.format(
flag_name="length_penalty", flag_value=self.length_penalty
)
if self.constraints is not None:
minor_issues["constraints"] = single_beam_wrong_parameter_msg.format(
flag_name="constraints", flag_value=self.constraints
)
# DoLa generation needs num_beams == 1
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
    # The issue description must be a plain string: it is later interpolated into the
    # "- `<attribute>`: <description>" summary. A trailing comma after the last literal would
    # turn this parenthesised expression into a 1-tuple and the summary would show its repr.
    minor_issues["repetition_penalty"] = (
        "`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
        f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
        "DoLa decoding is `repetition_penalty>=1.2`."
    )
# 2.3. detect incorrect parameterization specific to advanced beam modes
else:
# constrained beam search
if self.constraints is not None or self.force_words_ids is not None:
constrained_wrong_parameter_msg = (
"one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. "
"However, `{flag_name}` is set to `{flag_value}`, which is incompatible with this generation "
"mode. Set `constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue."
)
if self.do_sample is True:
raise ValueError(
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
)
if self.num_beam_groups is not None and self.num_beam_groups != 1:
raise ValueError(
constrained_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups
)
)
# group beam search
elif self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
group_error_prefix = (
"`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
"this generation mode, "
)
if self.do_sample is True:
raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
if self.num_beams % self.num_beam_groups != 0:
raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
if self.diversity_penalty == 0.0:
raise ValueError(
group_error_prefix
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
)
# 2.4. check `num_return_sequences`
if self.num_return_sequences != 1:
if self.num_beams == 1:
if self.do_sample is False:
raise ValueError(
"Greedy methods without beam search do not support `num_return_sequences` different than 1 "
f"(got {self.num_return_sequences})."
)
elif self.num_return_sequences > self.num_beams:
raise ValueError(
f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
f"({self.num_beams})."
)
# 2.5. check cache-related arguments
if self.use_cache is False: if self.use_cache is False:
# In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often # In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often
# passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error # passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error
@ -794,42 +781,20 @@ class GenerationConfig(PushToHubMixin):
) )
for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"): for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"):
if getattr(self, arg_name) is not None: if getattr(self, arg_name) is not None:
logger.warning_once( minor_issues[arg_name] = no_cache_warning.format(
no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)) cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)
) )
# 6. check watermarking arguments # 2.6. other incorrect combinations
if self.watermarking_config is not None:
if not (
isinstance(self.watermarking_config, WatermarkingConfig)
or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
):
warnings.warn(
"`watermarking_config` as a dict is deprecated. Please construct `watermarking_config` object with "
"`WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class.",
FutureWarning,
)
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
self.watermarking_config.validate()
# 7. performances arguments
if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
raise ValueError(
f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
"instance of `CompileConfig`."
)
# 8. other incorrect combinations
if self.return_dict_in_generate is not True: if self.return_dict_in_generate is not True:
for extra_output_flag in self.extra_output_flags: for extra_output_flag in self.extra_output_flags:
if getattr(self, extra_output_flag) is True: if getattr(self, extra_output_flag) is True:
warnings.warn( minor_issues[extra_output_flag] = (
f"`return_dict_in_generate` is NOT set to `True`, but `{extra_output_flag}` is. When " f"`return_dict_in_generate` is NOT set to `True`, but `{extra_output_flag}` is. When "
f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored.", f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored."
UserWarning,
) )
# 8. check common issue: passing `generate` arguments inside the generation config # 3. Check common issue: passing `generate` arguments inside the generation config
generate_arguments = ( generate_arguments = (
"logits_processor", "logits_processor",
"stopping_criteria", "stopping_criteria",
@ -839,6 +804,7 @@ class GenerationConfig(PushToHubMixin):
"streamer", "streamer",
"negative_prompt_ids", "negative_prompt_ids",
"negative_prompt_attention_mask", "negative_prompt_attention_mask",
"use_model_defaults",
) )
for arg in generate_arguments: for arg in generate_arguments:
if hasattr(self, arg): if hasattr(self, arg):
@ -847,6 +813,30 @@ class GenerationConfig(PushToHubMixin):
"`generate()` (or a pipeline) directly." "`generate()` (or a pipeline) directly."
) )
# Finally, handle caught minor issues. With default parameterization, we will throw a minimal warning.
if len(minor_issues) > 0:
# Full list of issues with potential fixes
info_message = []
for attribute_name, issue_description in minor_issues.items():
info_message.append(f"- `{attribute_name}`: {issue_description}")
info_message = "\n".join(info_message)
info_message += (
"\nIf you're using a pretrained model, note that some of these attributes may be set through the "
"model's `generation_config.json` file."
)
if strict:
raise ValueError("GenerationConfig is invalid: \n" + info_message)
else:
attributes_with_issues = list(minor_issues.keys())
warning_message = (
f"The following generation flags are not valid and may be ignored: {attributes_with_issues}."
)
if logger.getEffectiveLevel() >= logging.WARNING:
warning_message += " Set `TRANSFORMERS_VERBOSITY=info` for more details."
logger.warning(warning_message)
logger.info(info_message)
def save_pretrained( def save_pretrained(
self, self,
save_directory: Union[str, os.PathLike], save_directory: Union[str, os.PathLike],
@ -871,18 +861,13 @@ class GenerationConfig(PushToHubMixin):
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
""" """
# At save time, validate the instance -- if any warning/exception is thrown, we refuse to save the instance. # At save time, validate the instance enforcing strictness -- if any warning/exception would be thrown, we
# refuse to save the instance.
# This strictness is enforced to prevent bad configurations from being saved and re-used. # This strictness is enforced to prevent bad configurations from being saved and re-used.
try: try:
with warnings.catch_warnings(record=True) as caught_warnings: self.validate(strict=True)
self.validate()
if len(caught_warnings) > 0:
raise ValueError(str([w.message for w in caught_warnings]))
except ValueError as exc: except ValueError as exc:
raise ValueError( raise ValueError(str(exc) + "\n\nFix these issues to save the configuration.")
"The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. "
"Fix these issues to save the configuration.\n\nThrown during validation:\n" + str(exc)
)
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)

View File

@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import contextlib import contextlib
import warnings
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union from typing import TYPE_CHECKING, Any, Callable, Optional, Union
@ -129,15 +128,10 @@ class Seq2SeqTrainer(Trainer):
# Strict validation to fail early. `GenerationConfig.save_pretrained()`, run at the end of training, throws # Strict validation to fail early. `GenerationConfig.save_pretrained()`, run at the end of training, throws
# an exception if there are warnings at validation time. # an exception if there are warnings at validation time.
try: try:
with warnings.catch_warnings(record=True) as caught_warnings: gen_config.validate(strict=True)
gen_config.validate()
if len(caught_warnings) > 0:
raise ValueError(str([w.message for w in caught_warnings]))
except ValueError as exc: except ValueError as exc:
raise ValueError( raise ValueError(str(exc) + "\n\nFix these issues to train your model.")
"The loaded generation config instance is invalid -- `GenerationConfig.validate()` throws warnings "
"and/or exceptions. Fix these issues to train your model.\n\nThrown during validation:\n" + str(exc)
)
return gen_config return gen_config
def evaluate( def evaluate(

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import copy import copy
import logging
import os import os
import tempfile import tempfile
import unittest import unittest
@ -22,6 +23,7 @@ from huggingface_hub import HfFolder, create_pull_request
from parameterized import parameterized from parameterized import parameterized
from transformers import AutoConfig, GenerationConfig, WatermarkingConfig, is_torch_available from transformers import AutoConfig, GenerationConfig, WatermarkingConfig, is_torch_available
from transformers import logging as transformers_logging
if is_torch_available(): if is_torch_available():
@ -55,7 +57,14 @@ from transformers.generation import (
UnbatchedClassifierFreeGuidanceLogitsProcessor, UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor, WatermarkLogitsProcessor,
) )
from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test, torch_device from transformers.testing_utils import (
TOKEN,
CaptureLogger,
LoggingLevel,
TemporaryHubRepo,
is_staging_test,
torch_device,
)
class GenerationConfigTest(unittest.TestCase): class GenerationConfigTest(unittest.TestCase):
@ -112,24 +121,6 @@ class GenerationConfigTest(unittest.TestCase):
# `.update()` returns a dictionary of unused kwargs # `.update()` returns a dictionary of unused kwargs
self.assertEqual(unused_kwargs, {"foo": "bar"}) self.assertEqual(unused_kwargs, {"foo": "bar"})
# TODO: @Arthur and/or @Joao
# FAILED tests/generation/test_configuration_utils.py::GenerationConfigTest::test_initialize_new_kwargs - AttributeError: 'GenerationConfig' object has no attribute 'get_text_config'
# See: https://app.circleci.com/pipelines/github/huggingface/transformers/104831/workflows/e5e61514-51b7-4c8c-bba7-3c4d2986956e/jobs/1394252
@unittest.skip("failed with `'GenerationConfig' object has no attribute 'get_text_config'`")
def test_initialize_new_kwargs(self):
generation_config = GenerationConfig()
generation_config.foo = "bar"
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
# update_kwargs was used to update the config on valid attributes
self.assertEqual(new_config.foo, "bar")
generation_config = GenerationConfig.from_model_config(new_config)
assert not hasattr(generation_config, "foo") # no new kwargs should be initialized if from config
def test_kwarg_init(self): def test_kwarg_init(self):
"""Tests that we can overwrite attributes at `from_pretrained` time.""" """Tests that we can overwrite attributes at `from_pretrained` time."""
default_config = GenerationConfig() default_config = GenerationConfig()
@ -159,38 +150,39 @@ class GenerationConfigTest(unittest.TestCase):
""" """
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
""" """
logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
# A correct configuration will not throw any warning # A correct configuration will not throw any warning
with warnings.catch_warnings(record=True) as captured_warnings: with CaptureLogger(logger) as captured_logs:
GenerationConfig() GenerationConfig()
self.assertEqual(len(captured_warnings), 0) self.assertEqual(len(captured_logs.out), 0)
# Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling # Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
# parameters with `do_sample=False`). May be escalated to an error in the future. # parameters with `do_sample=False`). May be escalated to an error in the future.
with warnings.catch_warnings(record=True) as captured_warnings: with CaptureLogger(logger) as captured_logs:
GenerationConfig(do_sample=False, temperature=0.5)
self.assertEqual(len(captured_warnings), 1)
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(return_dict_in_generate=False, output_scores=True) GenerationConfig(return_dict_in_generate=False, output_scores=True)
self.assertEqual(len(captured_warnings), 1) self.assertNotEqual(len(captured_logs.out), 0)
with CaptureLogger(logger) as captured_logs:
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5) # store for later
self.assertNotEqual(len(captured_logs.out), 0)
# Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally, # Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
# that is done by unsetting the parameter (i.e. setting it to None) # that is done by unsetting the parameter (i.e. setting it to None)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5) with CaptureLogger(logger) as captured_logs:
with warnings.catch_warnings(record=True) as captured_warnings:
# BAD - 0.9 means it is still set, we should warn # BAD - 0.9 means it is still set, we should warn
generation_config_bad_temperature.update(temperature=0.9) generation_config_bad_temperature.update(temperature=0.9)
self.assertEqual(len(captured_warnings), 1) self.assertNotEqual(len(captured_logs.out), 0)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings: with CaptureLogger(logger) as captured_logs:
# CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn # CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn
generation_config_bad_temperature.update(temperature=1.0) generation_config_bad_temperature.update(temperature=1.0)
self.assertEqual(len(captured_warnings), 0) self.assertEqual(len(captured_logs.out), 0)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings: with CaptureLogger(logger) as captured_logs:
# OK - None means it is unset, nothing to warn about # OK - None means it is unset, nothing to warn about
generation_config_bad_temperature.update(temperature=None) generation_config_bad_temperature.update(temperature=None)
self.assertEqual(len(captured_warnings), 0) self.assertEqual(len(captured_logs.out), 0)
# Impossible sets of constraints/parameters will raise an exception # Impossible sets of constraints/parameters will raise an exception
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@ -206,9 +198,32 @@ class GenerationConfigTest(unittest.TestCase):
GenerationConfig(logits_processor="foo") GenerationConfig(logits_processor="foo")
# Model-specific parameters will NOT raise an exception or a warning # Model-specific parameters will NOT raise an exception or a warning
with warnings.catch_warnings(record=True) as captured_warnings: with CaptureLogger(logger) as captured_logs:
GenerationConfig(foo="bar") GenerationConfig(foo="bar")
self.assertEqual(len(captured_warnings), 0) self.assertEqual(len(captured_logs.out), 0)
# By default we throw a short warning. However, we log with INFO level the details.
# Default: we don't log the incorrect input values, only a short summary. We explain how to get more details.
with CaptureLogger(logger) as captured_logs:
GenerationConfig(do_sample=False, temperature=0.5)
self.assertNotIn("0.5", captured_logs.out)
self.assertTrue(len(captured_logs.out) < 150) # short log
self.assertIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out)
# INFO level: we share the full deets
with LoggingLevel(logging.INFO):
with CaptureLogger(logger) as captured_logs:
GenerationConfig(do_sample=False, temperature=0.5)
self.assertIn("0.5", captured_logs.out)
self.assertTrue(len(captured_logs.out) > 400) # long log
self.assertNotIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out)
# Finally, we can set `strict=True` to raise an exception on what would otherwise be a warning.
generation_config = GenerationConfig()
generation_config.temperature = 0.5
generation_config.do_sample = False
with self.assertRaises(ValueError):
generation_config.validate(strict=True)
def test_refuse_to_save(self): def test_refuse_to_save(self):
"""Tests that we refuse to save a generation config that fails validation.""" """Tests that we refuse to save a generation config that fails validation."""
@ -221,6 +236,7 @@ class GenerationConfigTest(unittest.TestCase):
with self.assertRaises(ValueError) as exc: with self.assertRaises(ValueError) as exc:
config.save_pretrained(tmp_dir) config.save_pretrained(tmp_dir)
self.assertTrue("Fix these issues to save the configuration." in str(exc.exception)) self.assertTrue("Fix these issues to save the configuration." in str(exc.exception))
self.assertTrue("`temperature` is set to `0.5`" in str(exc.exception))
self.assertTrue(len(os.listdir(tmp_dir)) == 0) self.assertTrue(len(os.listdir(tmp_dir)) == 0)
# greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is # greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is
@ -231,15 +247,24 @@ class GenerationConfigTest(unittest.TestCase):
with self.assertRaises(ValueError) as exc: with self.assertRaises(ValueError) as exc:
config.save_pretrained(tmp_dir) config.save_pretrained(tmp_dir)
self.assertTrue("Fix these issues to save the configuration." in str(exc.exception)) self.assertTrue("Fix these issues to save the configuration." in str(exc.exception))
self.assertTrue(
"Greedy methods without beam search do not support `num_return_sequences` different than 1"
in str(exc.exception)
)
self.assertTrue(len(os.listdir(tmp_dir)) == 0) self.assertTrue(len(os.listdir(tmp_dir)) == 0)
# final check: no warnings/exceptions thrown if it is correct, and file is saved # Final check: no logs at warning level/warnings/exceptions thrown if it is correct, and file is saved.
config = GenerationConfig() config = GenerationConfig()
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
# Catch warnings
with warnings.catch_warnings(record=True) as captured_warnings: with warnings.catch_warnings(record=True) as captured_warnings:
# Catch logs (up to WARNING level, the default level)
logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
with CaptureLogger(logger) as captured_logs:
config.save_pretrained(tmp_dir) config.save_pretrained(tmp_dir)
self.assertEqual(len(captured_warnings), 0) self.assertEqual(len(captured_warnings), 0)
self.assertTrue(len(os.listdir(tmp_dir)) == 1) self.assertEqual(len(captured_logs.out), 0)
self.assertEqual(len(os.listdir(tmp_dir)), 1)
def test_generation_mode(self): def test_generation_mode(self):
"""Tests that the `get_generation_mode` method is working as expected.""" """Tests that the `get_generation_mode` method is working as expected."""

View File

@ -202,4 +202,4 @@ class Seq2seqTrainerTester(TestCasePlus):
data_collator=data_collator, data_collator=data_collator,
compute_metrics=lambda x: {"samples": x[0].shape[0]}, compute_metrics=lambda x: {"samples": x[0].shape[0]},
) )
self.assertIn("The loaded generation config instance is invalid", str(exc.exception)) self.assertIn("Fix these issues to train your model", str(exc.exception))