From dbb9813dff3d5dc4784ad3c7116cba96ccee16b6 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Mon, 19 May 2025 11:03:37 +0100
Subject: [PATCH] [generation] Less verbose warnings by default (#38179)

* tmp commit (imports broken)

* working version; update tests

* remove line break

* shorter msg

* dola checks need num_beams=1; other minor PR comments

* update early trainer failing on bad gen config

* make fixup

* test msg
---
 .../generation/configuration_utils.py        | 377 +++++++++---------
 src/transformers/trainer_seq2seq.py          |  12 +-
 tests/generation/test_configuration_utils.py | 107 +++--
 tests/trainer/test_trainer_seq2seq.py        |   2 +-
 4 files changed, 251 insertions(+), 247 deletions(-)

diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 12f612bd26c..4e0d658f981 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -35,6 +35,7 @@ from ..utils import (
     is_torch_available,
     logging,
 )
+from ..utils.deprecation import deprecate_kwarg
 
 
 if TYPE_CHECKING:
@@ -514,7 +515,7 @@ class GenerationConfig(PushToHubMixin):
                 raise err
 
         # Validate the values of the attributes
-        self.validate(is_init=True)
+        self.validate()
 
     def __hash__(self):
         return hash(self.to_json_string(ignore_metadata=True))
@@ -592,7 +593,8 @@ class GenerationConfig(PushToHubMixin):
             )
         return generation_mode
 
-    def validate(self, is_init=False):
+    @deprecate_kwarg("is_init", version="4.54.0")
+    def validate(self, strict=False):
         """
         Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
         of parameterization that can be detected as incorrect from the configuration instance alone.
@@ -600,174 +602,24 @@ class GenerationConfig(PushToHubMixin):
         Note that some parameters not validated here are best validated at generate runtime, as they may depend on
         other inputs and/or the model, such as parameters related to the generation length.
 
-        Arg:
-            is_init (`bool`, *optional*, defaults to `False`):
-                Whether the validation is performed during the initialization of the instance.
+        Args:
+            strict (bool): If True, raise an exception for any issues found. If False, only log issues.
         """
+        minor_issues = {}  # format: {attribute_name: issue_description}
 
-        # Validation of individual attributes
+        # 1. Validation of individual attributes
+        # 1.1. Decoding attributes
         if self.early_stopping not in {True, False, "never"}:
             raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
         if self.max_new_tokens is not None and self.max_new_tokens <= 0:
             raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
         if self.pad_token_id is not None and self.pad_token_id < 0:
-            warnings.warn(
+            minor_issues["pad_token_id"] = (
                 f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch "
                 "generating, if there is padding. Please set `pad_token_id` explicitly as "
                 "`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation"
             )
-
-        # Validation of attribute relations:
-        fix_location = ""
-        if is_init:
-            fix_location = (
-                " This was detected when initializing the generation config instance, which means the corresponding "
-                "file may hold incorrect parameterization and should be fixed."
-            )
-
-        # 1. detect sampling-only parameterization when not in sampling mode
-        if self.do_sample is False:
-            greedy_wrong_parameter_msg = (
-                "`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
-                "used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
-                + fix_location
-            )
-            if self.temperature is not None and self.temperature != 1.0:
-                warnings.warn(
-                    greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
-                    UserWarning,
-                )
-            if self.top_p is not None and self.top_p != 1.0:
-                warnings.warn(
-                    greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
-                    UserWarning,
-                )
-            if self.min_p is not None:
-                warnings.warn(
-                    greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p),
-                    UserWarning,
-                )
-            if self.typical_p is not None and self.typical_p != 1.0:
-                warnings.warn(
-                    greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
-                    UserWarning,
-                )
-            if (
-                self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
-            ):  # contrastive search uses top_k
-                warnings.warn(
-                    greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
-                    UserWarning,
-                )
-            if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
-                warnings.warn(
-                    greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
-                    UserWarning,
-                )
-            if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
-                warnings.warn(
-                    greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
-                    UserWarning,
-                )
-
-        # 2. detect beam-only parameterization when not in beam mode
-        if self.num_beams is None:
-            warnings.warn("`num_beams` is set to None - defaulting to 1.", UserWarning)
-            self.num_beams = 1
-
-        if self.num_beams == 1:
-            single_beam_wrong_parameter_msg = (
-                "`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
-                "in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`." + fix_location
-            )
-            if self.early_stopping is not False:
-                warnings.warn(
-                    single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
-                    UserWarning,
-                )
-            if self.num_beam_groups is not None and self.num_beam_groups != 1:
-                warnings.warn(
-                    single_beam_wrong_parameter_msg.format(
-                        flag_name="num_beam_groups", flag_value=self.num_beam_groups
-                    ),
-                    UserWarning,
-                )
-            if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
-                warnings.warn(
-                    single_beam_wrong_parameter_msg.format(
-                        flag_name="diversity_penalty", flag_value=self.diversity_penalty
-                    ),
-                    UserWarning,
-                )
-            if self.length_penalty is not None and self.length_penalty != 1.0:
-                warnings.warn(
-                    single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
-                    UserWarning,
-                )
-            if self.constraints is not None:
-                warnings.warn(
-                    single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints),
-                    UserWarning,
-                )
-
-        # 3. detect incorrect parameterization specific to advanced beam modes
-        else:
-            # constrained beam search
-            if self.constraints is not None or self.force_words_ids is not None:
-                constrained_wrong_parameter_msg = (
-                    "one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. However, "
-                    "`{flag_name}` is set to `{flag_value}`, which is incompatible with this generation mode. Set "
-                    "`constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue." + fix_location
-                )
-                if self.do_sample is True:
-                    raise ValueError(
-                        constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
-                    )
-                if self.num_beam_groups is not None and self.num_beam_groups != 1:
-                    raise ValueError(
-                        constrained_wrong_parameter_msg.format(
-                            flag_name="num_beam_groups", flag_value=self.num_beam_groups
-                        )
-                    )
-            # group beam search
-            if self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
-                group_error_prefix = (
-                    "`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
-                    "this generation mode, "
-                )
-                if self.do_sample is True:
-                    raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
-                if self.num_beams % self.num_beam_groups != 0:
-                    raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
-                if self.diversity_penalty == 0.0:
-                    raise ValueError(
-                        group_error_prefix
-                        + "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
-                    )
-            # DoLa generation
-            if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
-                warnings.warn(
-                    "`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
-                    f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
-                    "DoLa decoding is `repetition_penalty>=1.2`.",
-                    UserWarning,
-                )
-
-        # 4. check `num_return_sequences`
-        if self.num_return_sequences != 1:
-            if self.num_beams == 1:
-                if self.do_sample is False:
-                    raise ValueError(
-                        "Greedy methods without beam search do not support `num_return_sequences` different than 1 "
-                        f"(got {self.num_return_sequences})."
-                    )
-            elif self.num_return_sequences > self.num_beams:
-                raise ValueError(
-                    f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
-                    f"({self.num_beams})."
-                )
-
-        # 5. check cache-related arguments
+        # 1.2. Cache attributes
         if self.cache_implementation is not None and self.cache_implementation not in ALL_CACHE_IMPLEMENTATIONS:
             raise ValueError(
                 f"Invalid `cache_implementation` ({self.cache_implementation}). Choose one of: "
@@ -784,6 +636,141 @@ class GenerationConfig(PushToHubMixin):
             if not isinstance(self.cache_config, cache_class):
                 self.cache_config = cache_class.from_dict(self.cache_config)
             self.cache_config.validate()
+        # 1.3. Performance attributes
+        if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
+            raise ValueError(
+                f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
+                "instance of `CompileConfig`."
+            )
+        # 1.4. Watermarking attributes
+        if self.watermarking_config is not None:
+            if not (
+                isinstance(self.watermarking_config, WatermarkingConfig)
+                or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
+            ):
+                minor_issues["watermarking_config"] = (
+                    "`watermarking_config` as a dict is deprecated and will be removed in v4.54.0. Please construct "
+                    "`watermarking_config` object with `WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class."
+                )
+                self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
+            self.watermarking_config.validate()
+
+        # 2. Validation of attribute combinations
+        # 2.1. detect sampling-only parameterization when not in sampling mode
+        if self.do_sample is False:
+            greedy_wrong_parameter_msg = (
+                "`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
+                "used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
+            )
+            if self.temperature is not None and self.temperature != 1.0:
+                minor_issues["temperature"] = greedy_wrong_parameter_msg.format(
+                    flag_name="temperature", flag_value=self.temperature
+                )
+            if self.top_p is not None and self.top_p != 1.0:
+                minor_issues["top_p"] = greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p)
+            if self.min_p is not None:
+                minor_issues["min_p"] = greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p)
+            if self.typical_p is not None and self.typical_p != 1.0:
+                minor_issues["typical_p"] = greedy_wrong_parameter_msg.format(
+                    flag_name="typical_p", flag_value=self.typical_p
+                )
+            if (
+                self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
+            ):  # contrastive search uses top_k
+                minor_issues["top_k"] = greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k)
+            if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
+                minor_issues["epsilon_cutoff"] = greedy_wrong_parameter_msg.format(
+                    flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff
+                )
+            if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
+                minor_issues["eta_cutoff"] = greedy_wrong_parameter_msg.format(
+                    flag_name="eta_cutoff", flag_value=self.eta_cutoff
+                )
+
+        # 2.2. detect beam-only parameterization when not in beam mode
+        if self.num_beams == 1:
+            single_beam_wrong_parameter_msg = (
+                "`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
+                "in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`."
+            )
+            if self.early_stopping is not False:
+                minor_issues["early_stopping"] = single_beam_wrong_parameter_msg.format(
+                    flag_name="early_stopping", flag_value=self.early_stopping
+                )
+            if self.num_beam_groups is not None and self.num_beam_groups != 1:
+                minor_issues["num_beam_groups"] = single_beam_wrong_parameter_msg.format(
+                    flag_name="num_beam_groups", flag_value=self.num_beam_groups
+                )
+            if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
+                minor_issues["diversity_penalty"] = single_beam_wrong_parameter_msg.format(
+                    flag_name="diversity_penalty", flag_value=self.diversity_penalty
+                )
+            if self.length_penalty is not None and self.length_penalty != 1.0:
+                minor_issues["length_penalty"] = single_beam_wrong_parameter_msg.format(
+                    flag_name="length_penalty", flag_value=self.length_penalty
+                )
+            if self.constraints is not None:
+                minor_issues["constraints"] = single_beam_wrong_parameter_msg.format(
+                    flag_name="constraints", flag_value=self.constraints
+                )
+            # DoLa generation needs num_beams == 1
+            if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
+                minor_issues["repetition_penalty"] = (
+                    "`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
+                    f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
+                    "DoLa decoding is `repetition_penalty>=1.2`."
+                )
+
+        # 2.3. detect incorrect parameterization specific to advanced beam modes
+        else:
+            # constrained beam search
+            if self.constraints is not None or self.force_words_ids is not None:
+                constrained_wrong_parameter_msg = (
+                    "one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. "
+                    "However, `{flag_name}` is set to `{flag_value}`, which is incompatible with this generation "
+                    "mode. Set `constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue."
+                )
+                if self.do_sample is True:
+                    raise ValueError(
+                        constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
+                    )
+                if self.num_beam_groups is not None and self.num_beam_groups != 1:
+                    raise ValueError(
+                        constrained_wrong_parameter_msg.format(
                            flag_name="num_beam_groups", flag_value=self.num_beam_groups
+                        )
+                    )
+            # group beam search
+            elif self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
+                group_error_prefix = (
+                    "`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
+                    "this generation mode, "
+                )
+                if self.do_sample is True:
+                    raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
+                if self.num_beams % self.num_beam_groups != 0:
+                    raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
+                if self.diversity_penalty == 0.0:
+                    raise ValueError(
+                        group_error_prefix
+                        + "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
+                    )
+
+        # 2.4. check `num_return_sequences`
+        if self.num_return_sequences != 1:
+            if self.num_beams == 1:
+                if self.do_sample is False:
+                    raise ValueError(
+                        "Greedy methods without beam search do not support `num_return_sequences` different than 1 "
+                        f"(got {self.num_return_sequences})."
+                    )
+            elif self.num_return_sequences > self.num_beams:
+                raise ValueError(
+                    f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
+                    f"({self.num_beams})."
+                )
+
+        # 2.5. check cache-related arguments
         if self.use_cache is False:
             # In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often used
             # passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error
@@ -794,42 +781,20 @@ class GenerationConfig(PushToHubMixin):
             )
             for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"):
                 if getattr(self, arg_name) is not None:
-                    logger.warning_once(
-                        no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name))
+                    minor_issues[arg_name] = no_cache_warning.format(
+                        cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)
                     )
 
-        # 6. check watermarking arguments
-        if self.watermarking_config is not None:
-            if not (
-                isinstance(self.watermarking_config, WatermarkingConfig)
-                or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
-            ):
-                warnings.warn(
-                    "`watermarking_config` as a dict is deprecated. Please construct `watermarking_config` object with "
-                    "`WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class.",
-                    FutureWarning,
-                )
-                self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
-            self.watermarking_config.validate()
-
-        # 7. performances arguments
-        if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
-            raise ValueError(
-                f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
-                "instance of `CompileConfig`."
-            )
-
-        # 8. other incorrect combinations
+        # 2.6. other incorrect combinations
         if self.return_dict_in_generate is not True:
             for extra_output_flag in self.extra_output_flags:
                 if getattr(self, extra_output_flag) is True:
-                    warnings.warn(
+                    minor_issues[extra_output_flag] = (
                         f"`return_dict_in_generate` is NOT set to `True`, but `{extra_output_flag}` is. When "
-                        f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored.",
-                        UserWarning,
+                        f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored."
                     )
 
-        # 8. check common issue: passing `generate` arguments inside the generation config
+        # 3. Check common issue: passing `generate` arguments inside the generation config
         generate_arguments = (
             "logits_processor",
             "stopping_criteria",
@@ -839,6 +804,7 @@ class GenerationConfig(PushToHubMixin):
             "streamer",
             "negative_prompt_ids",
             "negative_prompt_attention_mask",
+            "use_model_defaults",
         )
         for arg in generate_arguments:
             if hasattr(self, arg):
@@ -847,6 +813,30 @@ class GenerationConfig(PushToHubMixin):
                     "`generate()` (or a pipeline) directly."
                 )
 
+        # Finally, handle caught minor issues. With default parameterization, we will throw a minimal warning.
+        if len(minor_issues) > 0:
+            # Full list of issues with potential fixes
+            info_message = []
+            for attribute_name, issue_description in minor_issues.items():
+                info_message.append(f"- `{attribute_name}`: {issue_description}")
+            info_message = "\n".join(info_message)
+            info_message += (
+                "\nIf you're using a pretrained model, note that some of these attributes may be set through the "
+                "model's `generation_config.json` file."
+            )
+
+            if strict:
+                raise ValueError("GenerationConfig is invalid: \n" + info_message)
+            else:
+                attributes_with_issues = list(minor_issues.keys())
+                warning_message = (
+                    f"The following generation flags are not valid and may be ignored: {attributes_with_issues}."
+                )
+                if logger.getEffectiveLevel() >= logging.WARNING:
+                    warning_message += " Set `TRANSFORMERS_VERBOSITY=info` for more details."
+                logger.warning(warning_message)
+                logger.info(info_message)
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
@@ -871,18 +861,13 @@ class GenerationConfig(PushToHubMixin):
             Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
 
-        # At save time, validate the instance -- if any warning/exception is thrown, we refuse to save the instance.
+        # At save time, validate the instance enforcing strictness -- if any warning/exception would be thrown, we
+        # refuse to save the instance.
         # This strictness is enforced to prevent bad configurations from being saved and re-used.
         try:
-            with warnings.catch_warnings(record=True) as caught_warnings:
-                self.validate()
-            if len(caught_warnings) > 0:
-                raise ValueError(str([w.message for w in caught_warnings]))
+            self.validate(strict=True)
        except ValueError as exc:
-            raise ValueError(
-                "The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. "
-                "Fix these issues to save the configuration.\n\nThrown during validation:\n" + str(exc)
-            )
+            raise ValueError(str(exc) + "\n\nFix these issues to save the configuration.")
 
         use_auth_token = kwargs.pop("use_auth_token", None)
 
diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py
index e9fa797f062..e3ae588cd04 100644
--- a/src/transformers/trainer_seq2seq.py
+++ b/src/transformers/trainer_seq2seq.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import contextlib
-import warnings
 from copy import deepcopy
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
@@ -129,15 +128,10 @@ class Seq2SeqTrainer(Trainer):
         # Strict validation to fail early. `GenerationConfig.save_pretrained()`, run at the end of training, throws
         # an exception if there are warnings at validation time.
         try:
-            with warnings.catch_warnings(record=True) as caught_warnings:
-                gen_config.validate()
-            if len(caught_warnings) > 0:
-                raise ValueError(str([w.message for w in caught_warnings]))
+            gen_config.validate(strict=True)
         except ValueError as exc:
-            raise ValueError(
-                "The loaded generation config instance is invalid -- `GenerationConfig.validate()` throws warnings "
-                "and/or exceptions. Fix these issues to train your model.\n\nThrown during validation:\n" + str(exc)
-            )
+            raise ValueError(str(exc) + "\n\nFix these issues to train your model.")
+
         return gen_config
 
     def evaluate(
diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py
index f4f9764f48a..8a280b9e312 100644
--- a/tests/generation/test_configuration_utils.py
+++ b/tests/generation/test_configuration_utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import copy
+import logging
 import os
 import tempfile
 import unittest
@@ -22,6 +23,7 @@ from huggingface_hub import HfFolder, create_pull_request
 from parameterized import parameterized
 
 from transformers import AutoConfig, GenerationConfig, WatermarkingConfig, is_torch_available
+from transformers import logging as transformers_logging
 
 
 if is_torch_available():
@@ -55,7 +57,14 @@ from transformers.generation import (
     UnbatchedClassifierFreeGuidanceLogitsProcessor,
     WatermarkLogitsProcessor,
 )
-from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test, torch_device
+from transformers.testing_utils import (
+    TOKEN,
+    CaptureLogger,
+    LoggingLevel,
+    TemporaryHubRepo,
+    is_staging_test,
+    torch_device,
+)
 
 
 class GenerationConfigTest(unittest.TestCase):
@@ -112,24 +121,6 @@ class GenerationConfigTest(unittest.TestCase):
         # `.update()` returns a dictionary of unused kwargs
         self.assertEqual(unused_kwargs, {"foo": "bar"})
 
-    # TODO: @Arthur and/or @Joao
-    # FAILED tests/generation/test_configuration_utils.py::GenerationConfigTest::test_initialize_new_kwargs - AttributeError: 'GenerationConfig' object has no attribute 'get_text_config'
-    # See: https://app.circleci.com/pipelines/github/huggingface/transformers/104831/workflows/e5e61514-51b7-4c8c-bba7-3c4d2986956e/jobs/1394252
-    @unittest.skip("failed with `'GenerationConfig' object has no attribute 'get_text_config'`")
-    def test_initialize_new_kwargs(self):
-        generation_config = GenerationConfig()
-        generation_config.foo = "bar"
-
-        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
-            generation_config.save_pretrained(tmp_dir)
-
-            new_config = GenerationConfig.from_pretrained(tmp_dir)
-            # update_kwargs was used to update the config on valid attributes
-            self.assertEqual(new_config.foo, "bar")
-
-        generation_config = GenerationConfig.from_model_config(new_config)
-        assert not hasattr(generation_config, "foo")  # no new kwargs should be initialized if from config
-
     def test_kwarg_init(self):
         """Tests that we can overwrite attributes at `from_pretrained` time."""
         default_config = GenerationConfig()
@@ -159,38 +150,39 @@ class GenerationConfigTest(unittest.TestCase):
         """
         Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
         """
+        logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
+
         # A correct configuration will not throw any warning
-        with warnings.catch_warnings(record=True) as captured_warnings:
+        with CaptureLogger(logger) as captured_logs:
             GenerationConfig()
-        self.assertEqual(len(captured_warnings), 0)
+        self.assertEqual(len(captured_logs.out), 0)
 
         # Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
         # parameters with `do_sample=False`). May be escalated to an error in the future.
-        with warnings.catch_warnings(record=True) as captured_warnings:
-            GenerationConfig(do_sample=False, temperature=0.5)
-        self.assertEqual(len(captured_warnings), 1)
-
-        with warnings.catch_warnings(record=True) as captured_warnings:
+        with CaptureLogger(logger) as captured_logs:
             GenerationConfig(return_dict_in_generate=False, output_scores=True)
-        self.assertEqual(len(captured_warnings), 1)
+        self.assertNotEqual(len(captured_logs.out), 0)
+
+        with CaptureLogger(logger) as captured_logs:
+            generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)  # store for later
+        self.assertNotEqual(len(captured_logs.out), 0)
 
         # Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
         # that is done by unsetting the parameter (i.e. setting it to None)
-        generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
-        with warnings.catch_warnings(record=True) as captured_warnings:
+        with CaptureLogger(logger) as captured_logs:
             # BAD - 0.9 means it is still set, we should warn
             generation_config_bad_temperature.update(temperature=0.9)
-        self.assertEqual(len(captured_warnings), 1)
-        generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
-        with warnings.catch_warnings(record=True) as captured_warnings:
+        self.assertNotEqual(len(captured_logs.out), 0)
+
+        with CaptureLogger(logger) as captured_logs:
             # CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn
             generation_config_bad_temperature.update(temperature=1.0)
-        self.assertEqual(len(captured_warnings), 0)
-        generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
-        with warnings.catch_warnings(record=True) as captured_warnings:
+        self.assertEqual(len(captured_logs.out), 0)
+
+        with CaptureLogger(logger) as captured_logs:
             # OK - None means it is unset, nothing to warn about
             generation_config_bad_temperature.update(temperature=None)
-        self.assertEqual(len(captured_warnings), 0)
+        self.assertEqual(len(captured_logs.out), 0)
 
         # Impossible sets of constraints/parameters will raise an exception
         with self.assertRaises(ValueError):
@@ -206,9 +198,32 @@ class GenerationConfigTest(unittest.TestCase):
             GenerationConfig(logits_processor="foo")
 
         # Model-specific parameters will NOT raise an exception or a warning
-        with warnings.catch_warnings(record=True) as captured_warnings:
+        with CaptureLogger(logger) as captured_logs:
             GenerationConfig(foo="bar")
-        self.assertEqual(len(captured_warnings), 0)
+        self.assertEqual(len(captured_logs.out), 0)
+
+        # By default we throw a short warning. However, we log the details at INFO level.
+        # Default: we don't log the incorrect input values, only a short summary. We explain how to get more details.
+ with CaptureLogger(logger) as captured_logs: + GenerationConfig(do_sample=False, temperature=0.5) + self.assertNotIn("0.5", captured_logs.out) + self.assertTrue(len(captured_logs.out) < 150) # short log + self.assertIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out) + + # INFO level: we share the full deets + with LoggingLevel(logging.INFO): + with CaptureLogger(logger) as captured_logs: + GenerationConfig(do_sample=False, temperature=0.5) + self.assertIn("0.5", captured_logs.out) + self.assertTrue(len(captured_logs.out) > 400) # long log + self.assertNotIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out) + + # Finally, we can set `strict=True` to raise an exception on what would otherwise be a warning. + generation_config = GenerationConfig() + generation_config.temperature = 0.5 + generation_config.do_sample = False + with self.assertRaises(ValueError): + generation_config.validate(strict=True) def test_refuse_to_save(self): """Tests that we refuse to save a generation config that fails validation.""" @@ -221,6 +236,7 @@ class GenerationConfigTest(unittest.TestCase): with self.assertRaises(ValueError) as exc: config.save_pretrained(tmp_dir) self.assertTrue("Fix these issues to save the configuration." in str(exc.exception)) + self.assertTrue("`temperature` is set to `0.5`" in str(exc.exception)) self.assertTrue(len(os.listdir(tmp_dir)) == 0) # greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is @@ -231,15 +247,24 @@ class GenerationConfigTest(unittest.TestCase): with self.assertRaises(ValueError) as exc: config.save_pretrained(tmp_dir) self.assertTrue("Fix these issues to save the configuration." in str(exc.exception)) + self.assertTrue( + "Greedy methods without beam search do not support `num_return_sequences` different than 1" + in str(exc.exception) + ) self.assertTrue(len(os.listdir(tmp_dir)) == 0) - # final check: no warnings/exceptions thrown if it is correct, and file is saved + # Final check: no logs at warning level/warnings/exceptions thrown if it is correct, and file is saved. config = GenerationConfig() with tempfile.TemporaryDirectory() as tmp_dir: + # Catch warnings with warnings.catch_warnings(record=True) as captured_warnings: - config.save_pretrained(tmp_dir) + # Catch logs (up to WARNING level, the default level) + logger = transformers_logging.get_logger("transformers.generation.configuration_utils") + with CaptureLogger(logger) as captured_logs: + config.save_pretrained(tmp_dir) self.assertEqual(len(captured_warnings), 0) - self.assertTrue(len(os.listdir(tmp_dir)) == 1) + self.assertEqual(len(captured_logs.out), 0) + self.assertEqual(len(os.listdir(tmp_dir)), 1) def test_generation_mode(self): """Tests that the `get_generation_mode` method is working as expected.""" diff --git a/tests/trainer/test_trainer_seq2seq.py b/tests/trainer/test_trainer_seq2seq.py index 30d34e003a6..0c4716a2bce 100644 --- a/tests/trainer/test_trainer_seq2seq.py +++ b/tests/trainer/test_trainer_seq2seq.py @@ -202,4 +202,4 @@ class Seq2seqTrainerTester(TestCasePlus): data_collator=data_collator, compute_metrics=lambda x: {"samples": x[0].shape[0]}, ) - self.assertIn("The loaded generation config instance is invalid", str(exc.exception)) + self.assertIn("Fix these issues to train your model", str(exc.exception))
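
Usage sketch of the new validation behavior (assumes a transformers build that includes this patch; the quoted log text is abridged from the messages added above):

    from transformers import GenerationConfig

    # Flags that only apply to other generation modes are now collected into
    # `minor_issues` and reported once, as a single short log line, instead of
    # one Python warning per flag:
    config = GenerationConfig(do_sample=False, temperature=0.5, top_p=0.8)
    # logs: "The following generation flags are not valid and may be ignored:
    # ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details."

    # `strict=True` turns the same findings into an exception; `save_pretrained()`
    # and `Seq2SeqTrainer` rely on this to refuse bad configurations:
    try:
        config.validate(strict=True)
    except ValueError as exc:
        print(exc)  # "GenerationConfig is invalid: \n- `temperature`: ..."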