[generation] Less verbose warnings by default (#38179)

* tmp commit (imports broken)

* working version; update tests

* remove line break

* shorter msg

* dola checks need num_beams=1; other minor PR comments

* update early trainer failing on bad gen config

* make fixup

* test msg
Joao Gante 2025-05-19 11:03:37 +01:00 committed by GitHub
parent 656e2eab3f
commit dbb9813dff
4 changed files with 251 additions and 247 deletions
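In short: contradictory but harmless generation flags no longer each raise a full UserWarning. `validate()` now collects them into a `minor_issues` dict and reports them once, as a single short WARNING log line, with full per-attribute details available at INFO verbosity or via the new `strict=True` argument. A minimal usage sketch of the new behavior (all names taken from the diff below):

from transformers import GenerationConfig

# `temperature` only matters when `do_sample=True`, so this gets flagged.
# Default behavior after this PR: a single short WARNING log line, roughly
#   "The following generation flags are not valid and may be ignored:
#   ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details."
config = GenerationConfig(do_sample=False, temperature=0.5)

# Opting into strictness turns the aggregated issues into an exception.
try:
    config.validate(strict=True)
except ValueError as exc:
    print(exc)  # full details, one line per flagged attribute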


@@ -35,6 +35,7 @@ from ..utils import (
is_torch_available,
logging,
)
from ..utils.deprecation import deprecate_kwarg
if TYPE_CHECKING:
@@ -514,7 +515,7 @@ class GenerationConfig(PushToHubMixin):
raise err
# Validate the values of the attributes
self.validate(is_init=True)
self.validate()
def __hash__(self):
return hash(self.to_json_string(ignore_metadata=True))
@@ -592,7 +593,8 @@ class GenerationConfig(PushToHubMixin):
)
return generation_mode
def validate(self, is_init=False):
@deprecate_kwarg("is_init", version="4.54.0")
def validate(self, strict=False):
"""
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
of parameterization that can be detected as incorrect from the configuration instance alone.
@@ -600,174 +602,24 @@ class GenerationConfig(PushToHubMixin):
Note that some parameters not validated here are best validated at generate runtime, as they may depend on
other inputs and/or the model, such as parameters related to the generation length.
Arg:
is_init (`bool`, *optional*, defaults to `False`):
Whether the validation is performed during the initialization of the instance.
Args:
strict (bool): If True, raise an exception for any issues found. If False, only log issues.
"""
minor_issues = {} # format: {attribute_name: issue_description}
# Validation of individual attributes
# 1. Validation of individual attributes
# 1.1. Decoding attributes
if self.early_stopping not in {True, False, "never"}:
raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
if self.max_new_tokens is not None and self.max_new_tokens <= 0:
raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
if self.pad_token_id is not None and self.pad_token_id < 0:
warnings.warn(
minor_issues["pad_token_id"] = (
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch "
"generating, if there is padding. Please set `pad_token_id` explicitly as "
"`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation"
)
# Validation of attribute relations:
fix_location = ""
if is_init:
fix_location = (
" This was detected when initializing the generation config instance, which means the corresponding "
"file may hold incorrect parameterization and should be fixed."
)
# 1. detect sampling-only parameterization when not in sampling mode
if self.do_sample is False:
greedy_wrong_parameter_msg = (
"`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
+ fix_location
)
if self.temperature is not None and self.temperature != 1.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
UserWarning,
)
if self.top_p is not None and self.top_p != 1.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
UserWarning,
)
if self.min_p is not None:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p),
UserWarning,
)
if self.typical_p is not None and self.typical_p != 1.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
UserWarning,
)
if (
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
): # contrastive search uses top_k
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
UserWarning,
)
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
UserWarning,
)
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
UserWarning,
)
# 2. detect beam-only parameterization when not in beam mode
if self.num_beams is None:
warnings.warn("`num_beams` is set to None - defaulting to 1.", UserWarning)
self.num_beams = 1
if self.num_beams == 1:
single_beam_wrong_parameter_msg = (
"`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
"in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`." + fix_location
)
if self.early_stopping is not False:
warnings.warn(
single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
UserWarning,
)
if self.num_beam_groups is not None and self.num_beam_groups != 1:
warnings.warn(
single_beam_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups
),
UserWarning,
)
if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
warnings.warn(
single_beam_wrong_parameter_msg.format(
flag_name="diversity_penalty", flag_value=self.diversity_penalty
),
UserWarning,
)
if self.length_penalty is not None and self.length_penalty != 1.0:
warnings.warn(
single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
UserWarning,
)
if self.constraints is not None:
warnings.warn(
single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints),
UserWarning,
)
# 3. detect incorrect parameterization specific to advanced beam modes
else:
# constrained beam search
if self.constraints is not None or self.force_words_ids is not None:
constrained_wrong_parameter_msg = (
"one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. However, "
"`{flag_name}` is set to `{flag_value}`, which is incompatible with this generation mode. Set "
"`constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue." + fix_location
)
if self.do_sample is True:
raise ValueError(
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
)
if self.num_beam_groups is not None and self.num_beam_groups != 1:
raise ValueError(
constrained_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups
)
)
# group beam search
if self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
group_error_prefix = (
"`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
"this generation mode, "
)
if self.do_sample is True:
raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
if self.num_beams % self.num_beam_groups != 0:
raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
if self.diversity_penalty == 0.0:
raise ValueError(
group_error_prefix
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
)
# DoLa generation
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
warnings.warn(
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
"DoLa decoding is `repetition_penalty>=1.2`.",
UserWarning,
)
# 4. check `num_return_sequences`
if self.num_return_sequences != 1:
if self.num_beams == 1:
if self.do_sample is False:
raise ValueError(
"Greedy methods without beam search do not support `num_return_sequences` different than 1 "
f"(got {self.num_return_sequences})."
)
elif self.num_return_sequences > self.num_beams:
raise ValueError(
f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
f"({self.num_beams})."
)
# 5. check cache-related arguments
# 1.2. Cache attributes
if self.cache_implementation is not None and self.cache_implementation not in ALL_CACHE_IMPLEMENTATIONS:
raise ValueError(
f"Invalid `cache_implementation` ({self.cache_implementation}). Choose one of: "
@@ -784,6 +636,141 @@ class GenerationConfig(PushToHubMixin):
if not isinstance(self.cache_config, cache_class):
self.cache_config = cache_class.from_dict(self.cache_config)
self.cache_config.validate()
# 1.3. Performance attributes
if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
raise ValueError(
f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
"instance of `CompileConfig`."
)
# 1.4. Watermarking attributes
if self.watermarking_config is not None:
if not (
isinstance(self.watermarking_config, WatermarkingConfig)
or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
):
minor_issues["watermarking_config"] = (
"`watermarking_config` as a dict is deprecated and will be removed in v4.54.0. Please construct "
"`watermarking_config` object with `WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class."
)
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
self.watermarking_config.validate()
# 2. Validation of attribute combinations
# 2.1. detect sampling-only parameterization when not in sampling mode
if self.do_sample is False:
greedy_wrong_parameter_msg = (
"`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
)
if self.temperature is not None and self.temperature != 1.0:
minor_issues["temperature"] = greedy_wrong_parameter_msg.format(
flag_name="temperature", flag_value=self.temperature
)
if self.top_p is not None and self.top_p != 1.0:
minor_issues["top_p"] = greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p)
if self.min_p is not None:
minor_issues["min_p"] = greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p)
if self.typical_p is not None and self.typical_p != 1.0:
minor_issues["typical_p"] = greedy_wrong_parameter_msg.format(
flag_name="typical_p", flag_value=self.typical_p
)
if (
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
): # contrastive search uses top_k
minor_issues["top_k"] = greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k)
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
minor_issues["epsilon_cutoff"] = greedy_wrong_parameter_msg.format(
flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff
)
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
minor_issues["eta_cutoff"] = greedy_wrong_parameter_msg.format(
flag_name="eta_cutoff", flag_value=self.eta_cutoff
)
# 2.2. detect beam-only parameterization when not in beam mode
if self.num_beams == 1:
single_beam_wrong_parameter_msg = (
"`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
"in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`."
)
if self.early_stopping is not False:
minor_issues["early_stopping"] = single_beam_wrong_parameter_msg.format(
flag_name="early_stopping", flag_value=self.early_stopping
)
if self.num_beam_groups is not None and self.num_beam_groups != 1:
minor_issues["num_beam_groups"] = single_beam_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups
)
if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
minor_issues["diversity_penalty"] = single_beam_wrong_parameter_msg.format(
flag_name="diversity_penalty", flag_value=self.diversity_penalty
)
if self.length_penalty is not None and self.length_penalty != 1.0:
minor_issues["length_penalty"] = single_beam_wrong_parameter_msg.format(
flag_name="length_penalty", flag_value=self.length_penalty
)
if self.constraints is not None:
minor_issues["constraints"] = single_beam_wrong_parameter_msg.format(
flag_name="constraints", flag_value=self.constraints
)
# DoLa generation needs num_beams == 1
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
minor_issues["repetition_penalty"] = (
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
"DoLa decoding is `repetition_penalty>=1.2`.",
)
# 2.3. detect incorrect parameterization specific to advanced beam modes
else:
# constrained beam search
if self.constraints is not None or self.force_words_ids is not None:
constrained_wrong_parameter_msg = (
"one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. "
"However, `{flag_name}` is set to `{flag_value}`, which is incompatible with this generation "
"mode. Set `constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue."
)
if self.do_sample is True:
raise ValueError(
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
)
if self.num_beam_groups is not None and self.num_beam_groups != 1:
raise ValueError(
constrained_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups
)
)
# group beam search
elif self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
group_error_prefix = (
"`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
"this generation mode, "
)
if self.do_sample is True:
raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
if self.num_beams % self.num_beam_groups != 0:
raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
if self.diversity_penalty == 0.0:
raise ValueError(
group_error_prefix
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
)
# 2.4. check `num_return_sequences`
if self.num_return_sequences != 1:
if self.num_beams == 1:
if self.do_sample is False:
raise ValueError(
"Greedy methods without beam search do not support `num_return_sequences` different than 1 "
f"(got {self.num_return_sequences})."
)
elif self.num_return_sequences > self.num_beams:
raise ValueError(
f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
f"({self.num_beams})."
)
# 2.5. check cache-related arguments
if self.use_cache is False:
# In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often
# passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error
@@ -794,42 +781,20 @@ class GenerationConfig(PushToHubMixin):
)
for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"):
if getattr(self, arg_name) is not None:
logger.warning_once(
no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name))
minor_issues[arg_name] = no_cache_warning.format(
cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)
)
# 6. check watermarking arguments
if self.watermarking_config is not None:
if not (
isinstance(self.watermarking_config, WatermarkingConfig)
or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
):
warnings.warn(
"`watermarking_config` as a dict is deprecated. Please construct `watermarking_config` object with "
"`WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class.",
FutureWarning,
)
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
self.watermarking_config.validate()
# 7. performances arguments
if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
raise ValueError(
f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
"instance of `CompileConfig`."
)
# 8. other incorrect combinations
# 2.6. other incorrect combinations
if self.return_dict_in_generate is not True:
for extra_output_flag in self.extra_output_flags:
if getattr(self, extra_output_flag) is True:
warnings.warn(
minor_issues[extra_output_flag] = (
f"`return_dict_in_generate` is NOT set to `True`, but `{extra_output_flag}` is. When "
f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored.",
UserWarning,
f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored."
)
# 8. check common issue: passing `generate` arguments inside the generation config
# 3. Check common issue: passing `generate` arguments inside the generation config
generate_arguments = (
"logits_processor",
"stopping_criteria",
@@ -839,6 +804,7 @@
"streamer",
"negative_prompt_ids",
"negative_prompt_attention_mask",
"use_model_defaults",
)
for arg in generate_arguments:
if hasattr(self, arg):
@@ -847,6 +813,30 @@
"`generate()` (or a pipeline) directly."
)
# Finally, handle caught minor issues. With the default arguments (strict=False), we only log a minimal warning.
if len(minor_issues) > 0:
# Full list of issues with potential fixes
info_message = []
for attribute_name, issue_description in minor_issues.items():
info_message.append(f"- `{attribute_name}`: {issue_description}")
info_message = "\n".join(info_message)
info_message += (
"\nIf you're using a pretrained model, note that some of these attributes may be set through the "
"model's `generation_config.json` file."
)
if strict:
raise ValueError("GenerationConfig is invalid: \n" + info_message)
else:
attributes_with_issues = list(minor_issues.keys())
warning_message = (
f"The following generation flags are not valid and may be ignored: {attributes_with_issues}."
)
if logger.getEffectiveLevel() >= logging.WARNING:
warning_message += " Set `TRANSFORMERS_VERBOSITY=info` for more details."
logger.warning(warning_message)
logger.info(info_message)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
@@ -871,18 +861,13 @@
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
# At save time, validate the instance -- if any warning/exception is thrown, we refuse to save the instance.
# At save time, validate the instance enforcing strictness -- if any warning/exception would be thrown, we
# refuse to save the instance.
# This strictness is enforced to prevent bad configurations from being saved and re-used.
try:
with warnings.catch_warnings(record=True) as caught_warnings:
self.validate()
if len(caught_warnings) > 0:
raise ValueError(str([w.message for w in caught_warnings]))
self.validate(strict=True)
except ValueError as exc:
raise ValueError(
"The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. "
"Fix these issues to save the configuration.\n\nThrown during validation:\n" + str(exc)
)
raise ValueError(str(exc) + "\n\nFix these issues to save the configuration.")
use_auth_token = kwargs.pop("use_auth_token", None)
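That concludes the `GenerationConfig` changes (presumably `src/transformers/generation/configuration_utils.py`). Note that `save_pretrained` now reuses the same code path: it calls `validate(strict=True)` and refuses to write an invalid config to disk. A small sketch of the caller-side behavior described above:

import tempfile

from transformers import GenerationConfig

bad_config = GenerationConfig(do_sample=False, temperature=0.5)
with tempfile.TemporaryDirectory() as tmp_dir:
    try:
        bad_config.save_pretrained(tmp_dir)  # refuses to save
    except ValueError as exc:
        print(exc)  # ends with "Fix these issues to save the configuration."
    # The directory stays empty, as asserted in the tests below.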


@@ -13,7 +13,6 @@
# limitations under the License.
import contextlib
import warnings
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
@@ -129,15 +128,10 @@ class Seq2SeqTrainer(Trainer):
# Strict validation to fail early. `GenerationConfig.save_pretrained()`, run at the end of training, throws
# an exception if there are warnings at validation time.
try:
with warnings.catch_warnings(record=True) as caught_warnings:
gen_config.validate()
if len(caught_warnings) > 0:
raise ValueError(str([w.message for w in caught_warnings]))
gen_config.validate(strict=True)
except ValueError as exc:
raise ValueError(
"The loaded generation config instance is invalid -- `GenerationConfig.validate()` throws warnings "
"and/or exceptions. Fix these issues to train your model.\n\nThrown during validation:\n" + str(exc)
)
raise ValueError(str(exc) + "\n\nFix these issues to train your model.")
return gen_config
def evaluate(
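With this change (presumably `src/transformers/trainer_seq2seq.py`), a bad generation config fails at trainer construction time rather than at the end of training. A hedged sketch, assuming `Seq2SeqTrainingArguments` accepts a `generation_config` argument as in current releases:

from transformers import GenerationConfig, Seq2SeqTrainingArguments

bad_gen_config = GenerationConfig(do_sample=False, temperature=0.5)
args = Seq2SeqTrainingArguments(
    output_dir="out",
    predict_with_generate=True,
    generation_config=bad_gen_config,
)
# Constructing Seq2SeqTrainer(model=..., args=args, ...) now raises a
# ValueError that ends with "Fix these issues to train your model."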


@@ -13,6 +13,7 @@
# limitations under the License.
import copy
import logging
import os
import tempfile
import unittest
@@ -22,6 +23,7 @@ from huggingface_hub import HfFolder, create_pull_request
from parameterized import parameterized
from transformers import AutoConfig, GenerationConfig, WatermarkingConfig, is_torch_available
from transformers import logging as transformers_logging
if is_torch_available():
@@ -55,7 +57,14 @@ from transformers.generation import (
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
)
from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test, torch_device
from transformers.testing_utils import (
TOKEN,
CaptureLogger,
LoggingLevel,
TemporaryHubRepo,
is_staging_test,
torch_device,
)
class GenerationConfigTest(unittest.TestCase):
@@ -112,24 +121,6 @@ class GenerationConfigTest(unittest.TestCase):
# `.update()` returns a dictionary of unused kwargs
self.assertEqual(unused_kwargs, {"foo": "bar"})
# TODO: @Arthur and/or @Joao
# FAILED tests/generation/test_configuration_utils.py::GenerationConfigTest::test_initialize_new_kwargs - AttributeError: 'GenerationConfig' object has no attribute 'get_text_config'
# See: https://app.circleci.com/pipelines/github/huggingface/transformers/104831/workflows/e5e61514-51b7-4c8c-bba7-3c4d2986956e/jobs/1394252
@unittest.skip("failed with `'GenerationConfig' object has no attribute 'get_text_config'`")
def test_initialize_new_kwargs(self):
generation_config = GenerationConfig()
generation_config.foo = "bar"
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
# update_kwargs was used to update the config on valid attributes
self.assertEqual(new_config.foo, "bar")
generation_config = GenerationConfig.from_model_config(new_config)
assert not hasattr(generation_config, "foo") # no new kwargs should be initialized if from config
def test_kwarg_init(self):
"""Tests that we can overwrite attributes at `from_pretrained` time."""
default_config = GenerationConfig()
@@ -159,38 +150,39 @@ class GenerationConfigTest(unittest.TestCase):
"""
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
"""
logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
# A correct configuration will not throw any warning
with warnings.catch_warnings(record=True) as captured_warnings:
with CaptureLogger(logger) as captured_logs:
GenerationConfig()
self.assertEqual(len(captured_warnings), 0)
self.assertEqual(len(captured_logs.out), 0)
# An inconsequential but technically wrong configuration will throw a warning (e.g. setting sampling
# parameters with `do_sample=False`). May be escalated to an error in the future.
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(do_sample=False, temperature=0.5)
self.assertEqual(len(captured_warnings), 1)
with warnings.catch_warnings(record=True) as captured_warnings:
with CaptureLogger(logger) as captured_logs:
GenerationConfig(return_dict_in_generate=False, output_scores=True)
self.assertEqual(len(captured_warnings), 1)
self.assertNotEqual(len(captured_logs.out), 0)
with CaptureLogger(logger) as captured_logs:
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5) # store for later
self.assertNotEqual(len(captured_logs.out), 0)
# Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
# that is done by unsetting the parameter (i.e. setting it to None)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings:
with CaptureLogger(logger) as captured_logs:
# BAD - 0.9 means it is still set, we should warn
generation_config_bad_temperature.update(temperature=0.9)
self.assertEqual(len(captured_warnings), 1)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings:
self.assertNotEqual(len(captured_logs.out), 0)
with CaptureLogger(logger) as captured_logs:
# CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn
generation_config_bad_temperature.update(temperature=1.0)
self.assertEqual(len(captured_warnings), 0)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings:
self.assertEqual(len(captured_logs.out), 0)
with CaptureLogger(logger) as captured_logs:
# OK - None means it is unset, nothing to warn about
generation_config_bad_temperature.update(temperature=None)
self.assertEqual(len(captured_warnings), 0)
self.assertEqual(len(captured_logs.out), 0)
# Impossible sets of constraints/parameters will raise an exception
with self.assertRaises(ValueError):
@@ -206,9 +198,32 @@ class GenerationConfigTest(unittest.TestCase):
GenerationConfig(logits_processor="foo")
# Model-specific parameters will NOT raise an exception or a warning
with warnings.catch_warnings(record=True) as captured_warnings:
with CaptureLogger(logger) as captured_logs:
GenerationConfig(foo="bar")
self.assertEqual(len(captured_warnings), 0)
self.assertEqual(len(captured_logs.out), 0)
# By default we only log a short summary at WARNING level, without the incorrect input values, and we
# explain how to get the full details (logged at INFO level).
with CaptureLogger(logger) as captured_logs:
GenerationConfig(do_sample=False, temperature=0.5)
self.assertNotIn("0.5", captured_logs.out)
self.assertTrue(len(captured_logs.out) < 150) # short log
self.assertIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out)
# INFO level: we log the full details
with LoggingLevel(logging.INFO):
with CaptureLogger(logger) as captured_logs:
GenerationConfig(do_sample=False, temperature=0.5)
self.assertIn("0.5", captured_logs.out)
self.assertTrue(len(captured_logs.out) > 400) # long log
self.assertNotIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out)
# Finally, we can set `strict=True` to raise an exception on what would otherwise be a warning.
generation_config = GenerationConfig()
generation_config.temperature = 0.5
generation_config.do_sample = False
with self.assertRaises(ValueError):
generation_config.validate(strict=True)
def test_refuse_to_save(self):
"""Tests that we refuse to save a generation config that fails validation."""
@@ -221,6 +236,7 @@ class GenerationConfigTest(unittest.TestCase):
with self.assertRaises(ValueError) as exc:
config.save_pretrained(tmp_dir)
self.assertTrue("Fix these issues to save the configuration." in str(exc.exception))
self.assertTrue("`temperature` is set to `0.5`" in str(exc.exception))
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
# greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is
@@ -231,15 +247,24 @@ class GenerationConfigTest(unittest.TestCase):
with self.assertRaises(ValueError) as exc:
config.save_pretrained(tmp_dir)
self.assertTrue("Fix these issues to save the configuration." in str(exc.exception))
self.assertTrue(
"Greedy methods without beam search do not support `num_return_sequences` different than 1"
in str(exc.exception)
)
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
# final check: no warnings/exceptions thrown if it is correct, and file is saved
# Final check: no logs at warning level/warnings/exceptions thrown if it is correct, and file is saved.
config = GenerationConfig()
with tempfile.TemporaryDirectory() as tmp_dir:
# Catch warnings
with warnings.catch_warnings(record=True) as captured_warnings:
config.save_pretrained(tmp_dir)
# Catch logs (up to WARNING level, the default level)
logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
with CaptureLogger(logger) as captured_logs:
config.save_pretrained(tmp_dir)
self.assertEqual(len(captured_warnings), 0)
self.assertTrue(len(os.listdir(tmp_dir)) == 1)
self.assertEqual(len(captured_logs.out), 0)
self.assertEqual(len(os.listdir(tmp_dir)), 1)
def test_generation_mode(self):
"""Tests that the `get_generation_mode` method is working as expected."""


@@ -202,4 +202,4 @@ class Seq2seqTrainerTester(TestCasePlus):
data_collator=data_collator,
compute_metrics=lambda x: {"samples": x[0].shape[0]},
)
self.assertIn("The loaded generation config instance is invalid", str(exc.exception))
self.assertIn("Fix these issues to train your model", str(exc.exception))