mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
[generation] Less verbose warnings by default (#38179)
* tmp commit (imports broken) * working version; update tests * remove line break * shorter msg * dola checks need num_beams=1; other minor PR comments * update early trainer failing on bad gen config * make fixup * test msg
This commit is contained in:
parent
656e2eab3f
commit
dbb9813dff
@ -35,6 +35,7 @@ from ..utils import (
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
from ..utils.deprecation import deprecate_kwarg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -514,7 +515,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
raise err
|
||||
|
||||
# Validate the values of the attributes
|
||||
self.validate(is_init=True)
|
||||
self.validate()
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.to_json_string(ignore_metadata=True))
|
||||
@ -592,7 +593,8 @@ class GenerationConfig(PushToHubMixin):
|
||||
)
|
||||
return generation_mode
|
||||
|
||||
def validate(self, is_init=False):
|
||||
@deprecate_kwarg("is_init", version="4.54.0")
|
||||
def validate(self, strict=False):
|
||||
"""
|
||||
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
|
||||
of parameterization that can be detected as incorrect from the configuration instance alone.
|
||||
@ -600,174 +602,24 @@ class GenerationConfig(PushToHubMixin):
|
||||
Note that some parameters not validated here are best validated at generate runtime, as they may depend on
|
||||
other inputs and/or the model, such as parameters related to the generation length.
|
||||
|
||||
Arg:
|
||||
is_init (`bool`, *optional*, defaults to `False`):
|
||||
Whether the validation is performed during the initialization of the instance.
|
||||
Args:
|
||||
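strict (`bool`, *optional*, defaults to `False`): If `True`, minor issues that would otherwise only be logged raise a `ValueError` instead. If `False`, minor issues are only logged. Severe issues raise an exception regardless of this flag.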
|
||||
"""
|
||||
minor_issues = {} # format: {attribute_name: issue_description}
|
||||
|
||||
# Validation of individual attributes
|
||||
# 1. Validation of individual attributes
|
||||
# 1.1. Decoding attributes
|
||||
if self.early_stopping not in {True, False, "never"}:
|
||||
raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
|
||||
if self.max_new_tokens is not None and self.max_new_tokens <= 0:
|
||||
raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
|
||||
if self.pad_token_id is not None and self.pad_token_id < 0:
|
||||
warnings.warn(
|
||||
minor_issues["pad_token_id"] = (
|
||||
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch "
|
||||
"generating, if there is padding. Please set `pad_token_id` explicitly as "
|
||||
"`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation"
|
||||
)
|
||||
|
||||
# Validation of attribute relations:
|
||||
fix_location = ""
|
||||
if is_init:
|
||||
fix_location = (
|
||||
" This was detected when initializing the generation config instance, which means the corresponding "
|
||||
"file may hold incorrect parameterization and should be fixed."
|
||||
)
|
||||
|
||||
# 1. detect sampling-only parameterization when not in sampling mode
|
||||
if self.do_sample is False:
|
||||
greedy_wrong_parameter_msg = (
|
||||
"`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
|
||||
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
|
||||
+ fix_location
|
||||
)
|
||||
if self.temperature is not None and self.temperature != 1.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
|
||||
UserWarning,
|
||||
)
|
||||
if self.top_p is not None and self.top_p != 1.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
|
||||
UserWarning,
|
||||
)
|
||||
if self.min_p is not None:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p),
|
||||
UserWarning,
|
||||
)
|
||||
if self.typical_p is not None and self.typical_p != 1.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
|
||||
UserWarning,
|
||||
)
|
||||
if (
|
||||
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
|
||||
): # contrastive search uses top_k
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
|
||||
UserWarning,
|
||||
)
|
||||
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
|
||||
UserWarning,
|
||||
)
|
||||
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
# 2. detect beam-only parameterization when not in beam mode
|
||||
if self.num_beams is None:
|
||||
warnings.warn("`num_beams` is set to None - defaulting to 1.", UserWarning)
|
||||
self.num_beams = 1
|
||||
|
||||
if self.num_beams == 1:
|
||||
single_beam_wrong_parameter_msg = (
|
||||
"`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
|
||||
"in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`." + fix_location
|
||||
)
|
||||
if self.early_stopping is not False:
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
|
||||
UserWarning,
|
||||
)
|
||||
if self.num_beam_groups is not None and self.num_beam_groups != 1:
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(
|
||||
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
||||
),
|
||||
UserWarning,
|
||||
)
|
||||
if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(
|
||||
flag_name="diversity_penalty", flag_value=self.diversity_penalty
|
||||
),
|
||||
UserWarning,
|
||||
)
|
||||
if self.length_penalty is not None and self.length_penalty != 1.0:
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
|
||||
UserWarning,
|
||||
)
|
||||
if self.constraints is not None:
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints),
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
# 3. detect incorrect parameterization specific to advanced beam modes
|
||||
else:
|
||||
# constrained beam search
|
||||
if self.constraints is not None or self.force_words_ids is not None:
|
||||
constrained_wrong_parameter_msg = (
|
||||
"one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. However, "
|
||||
"`{flag_name}` is set to `{flag_value}`, which is incompatible with this generation mode. Set "
|
||||
"`constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue." + fix_location
|
||||
)
|
||||
if self.do_sample is True:
|
||||
raise ValueError(
|
||||
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
|
||||
)
|
||||
if self.num_beam_groups is not None and self.num_beam_groups != 1:
|
||||
raise ValueError(
|
||||
constrained_wrong_parameter_msg.format(
|
||||
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
||||
)
|
||||
)
|
||||
# group beam search
|
||||
if self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
|
||||
group_error_prefix = (
|
||||
"`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
|
||||
"this generation mode, "
|
||||
)
|
||||
if self.do_sample is True:
|
||||
raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
|
||||
if self.num_beams % self.num_beam_groups != 0:
|
||||
raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
|
||||
if self.diversity_penalty == 0.0:
|
||||
raise ValueError(
|
||||
group_error_prefix
|
||||
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
|
||||
)
|
||||
# DoLa generation
|
||||
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
|
||||
warnings.warn(
|
||||
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
|
||||
f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
|
||||
"DoLa decoding is `repetition_penalty>=1.2`.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
# 4. check `num_return_sequences`
|
||||
if self.num_return_sequences != 1:
|
||||
if self.num_beams == 1:
|
||||
if self.do_sample is False:
|
||||
raise ValueError(
|
||||
"Greedy methods without beam search do not support `num_return_sequences` different than 1 "
|
||||
f"(got {self.num_return_sequences})."
|
||||
)
|
||||
elif self.num_return_sequences > self.num_beams:
|
||||
raise ValueError(
|
||||
f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
|
||||
f"({self.num_beams})."
|
||||
)
|
||||
|
||||
# 5. check cache-related arguments
|
||||
# 1.2. Cache attributes
|
||||
if self.cache_implementation is not None and self.cache_implementation not in ALL_CACHE_IMPLEMENTATIONS:
|
||||
raise ValueError(
|
||||
f"Invalid `cache_implementation` ({self.cache_implementation}). Choose one of: "
|
||||
@ -784,6 +636,141 @@ class GenerationConfig(PushToHubMixin):
|
||||
if not isinstance(self.cache_config, cache_class):
|
||||
self.cache_config = cache_class.from_dict(self.cache_config)
|
||||
self.cache_config.validate()
|
||||
# 1.3. Performance attributes
|
||||
if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
|
||||
raise ValueError(
|
||||
f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
|
||||
"instance of `CompileConfig`."
|
||||
)
|
||||
# 1.4. Watermarking attributes
|
||||
if self.watermarking_config is not None:
|
||||
if not (
|
||||
isinstance(self.watermarking_config, WatermarkingConfig)
|
||||
or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
|
||||
):
|
||||
minor_issues["watermarking_config"] = (
|
||||
"`watermarking_config` as a dict is deprecated and will be removed in v4.54.0. Please construct "
|
||||
"`watermarking_config` object with `WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class."
|
||||
)
|
||||
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
|
||||
self.watermarking_config.validate()
|
||||
|
||||
# 2. Validation of attribute combinations
|
||||
# 2.1. detect sampling-only parameterization when not in sampling mode
|
||||
if self.do_sample is False:
|
||||
greedy_wrong_parameter_msg = (
|
||||
"`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
|
||||
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
|
||||
)
|
||||
if self.temperature is not None and self.temperature != 1.0:
|
||||
minor_issues["temperature"] = greedy_wrong_parameter_msg.format(
|
||||
flag_name="temperature", flag_value=self.temperature
|
||||
)
|
||||
if self.top_p is not None and self.top_p != 1.0:
|
||||
minor_issues["top_p"] = greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p)
|
||||
if self.min_p is not None:
|
||||
minor_issues["min_p"] = greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p)
|
||||
if self.typical_p is not None and self.typical_p != 1.0:
|
||||
minor_issues["typical_p"] = greedy_wrong_parameter_msg.format(
|
||||
flag_name="typical_p", flag_value=self.typical_p
|
||||
)
|
||||
if (
|
||||
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
|
||||
): # contrastive search uses top_k
|
||||
minor_issues["top_k"] = greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k)
|
||||
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
|
||||
minor_issues["epsilon_cutoff"] = greedy_wrong_parameter_msg.format(
|
||||
flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff
|
||||
)
|
||||
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
|
||||
minor_issues["eta_cutoff"] = greedy_wrong_parameter_msg.format(
|
||||
flag_name="eta_cutoff", flag_value=self.eta_cutoff
|
||||
)
|
||||
|
||||
# 2.2. detect beam-only parameterization when not in beam mode
|
||||
if self.num_beams == 1:
|
||||
single_beam_wrong_parameter_msg = (
|
||||
"`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
|
||||
"in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`."
|
||||
)
|
||||
if self.early_stopping is not False:
|
||||
minor_issues["early_stopping"] = single_beam_wrong_parameter_msg.format(
|
||||
flag_name="early_stopping", flag_value=self.early_stopping
|
||||
)
|
||||
if self.num_beam_groups is not None and self.num_beam_groups != 1:
|
||||
minor_issues["num_beam_groups"] = single_beam_wrong_parameter_msg.format(
|
||||
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
||||
)
|
||||
if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
|
||||
minor_issues["diversity_penalty"] = single_beam_wrong_parameter_msg.format(
|
||||
flag_name="diversity_penalty", flag_value=self.diversity_penalty
|
||||
)
|
||||
if self.length_penalty is not None and self.length_penalty != 1.0:
|
||||
minor_issues["length_penalty"] = single_beam_wrong_parameter_msg.format(
|
||||
flag_name="length_penalty", flag_value=self.length_penalty
|
||||
)
|
||||
if self.constraints is not None:
|
||||
minor_issues["constraints"] = single_beam_wrong_parameter_msg.format(
|
||||
flag_name="constraints", flag_value=self.constraints
|
||||
)
|
||||
# DoLa generation needs num_beams == 1
|
||||
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
|
||||
minor_issues["repetition_penalty"] = (
|
||||
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
|
||||
f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
|
||||
"DoLa decoding is `repetition_penalty>=1.2`.",
|
||||
)
|
||||
|
||||
# 2.3. detect incorrect parameterization specific to advanced beam modes
|
||||
else:
|
||||
# constrained beam search
|
||||
if self.constraints is not None or self.force_words_ids is not None:
|
||||
constrained_wrong_parameter_msg = (
|
||||
"one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. "
|
||||
"However, `{flag_name}` is set to `{flag_value}`, which is incompatible with this generation "
|
||||
"mode. Set `constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue."
|
||||
)
|
||||
if self.do_sample is True:
|
||||
raise ValueError(
|
||||
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
|
||||
)
|
||||
if self.num_beam_groups is not None and self.num_beam_groups != 1:
|
||||
raise ValueError(
|
||||
constrained_wrong_parameter_msg.format(
|
||||
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
||||
)
|
||||
)
|
||||
# group beam search
|
||||
elif self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
|
||||
group_error_prefix = (
|
||||
"`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
|
||||
"this generation mode, "
|
||||
)
|
||||
if self.do_sample is True:
|
||||
raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
|
||||
if self.num_beams % self.num_beam_groups != 0:
|
||||
raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
|
||||
if self.diversity_penalty == 0.0:
|
||||
raise ValueError(
|
||||
group_error_prefix
|
||||
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
|
||||
)
|
||||
|
||||
# 2.4. check `num_return_sequences`
|
||||
if self.num_return_sequences != 1:
|
||||
if self.num_beams == 1:
|
||||
if self.do_sample is False:
|
||||
raise ValueError(
|
||||
"Greedy methods without beam search do not support `num_return_sequences` different than 1 "
|
||||
f"(got {self.num_return_sequences})."
|
||||
)
|
||||
elif self.num_return_sequences > self.num_beams:
|
||||
raise ValueError(
|
||||
f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
|
||||
f"({self.num_beams})."
|
||||
)
|
||||
|
||||
# 2.5. check cache-related arguments
|
||||
if self.use_cache is False:
|
||||
# In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often
|
||||
# passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error
|
||||
@ -794,42 +781,20 @@ class GenerationConfig(PushToHubMixin):
|
||||
)
|
||||
for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"):
|
||||
if getattr(self, arg_name) is not None:
|
||||
logger.warning_once(
|
||||
no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name))
|
||||
minor_issues[arg_name] = no_cache_warning.format(
|
||||
cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)
|
||||
)
|
||||
|
||||
# 6. check watermarking arguments
|
||||
if self.watermarking_config is not None:
|
||||
if not (
|
||||
isinstance(self.watermarking_config, WatermarkingConfig)
|
||||
or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
|
||||
):
|
||||
warnings.warn(
|
||||
"`watermarking_config` as a dict is deprecated. Please construct `watermarking_config` object with "
|
||||
"`WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class.",
|
||||
FutureWarning,
|
||||
)
|
||||
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
|
||||
self.watermarking_config.validate()
|
||||
|
||||
# 7. performances arguments
|
||||
if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
|
||||
raise ValueError(
|
||||
f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
|
||||
"instance of `CompileConfig`."
|
||||
)
|
||||
|
||||
# 8. other incorrect combinations
|
||||
# 2.6. other incorrect combinations
|
||||
if self.return_dict_in_generate is not True:
|
||||
for extra_output_flag in self.extra_output_flags:
|
||||
if getattr(self, extra_output_flag) is True:
|
||||
warnings.warn(
|
||||
minor_issues[extra_output_flag] = (
|
||||
f"`return_dict_in_generate` is NOT set to `True`, but `{extra_output_flag}` is. When "
|
||||
f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored.",
|
||||
UserWarning,
|
||||
f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored."
|
||||
)
|
||||
|
||||
# 8. check common issue: passing `generate` arguments inside the generation config
|
||||
# 3. Check common issue: passing `generate` arguments inside the generation config
|
||||
generate_arguments = (
|
||||
"logits_processor",
|
||||
"stopping_criteria",
|
||||
@ -839,6 +804,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
"streamer",
|
||||
"negative_prompt_ids",
|
||||
"negative_prompt_attention_mask",
|
||||
"use_model_defaults",
|
||||
)
|
||||
for arg in generate_arguments:
|
||||
if hasattr(self, arg):
|
||||
@ -847,6 +813,30 @@ class GenerationConfig(PushToHubMixin):
|
||||
"`generate()` (or a pipeline) directly."
|
||||
)
|
||||
|
||||
# Finally, handle caught minor issues. With default parameterization, we will throw a minimal warning.
|
||||
if len(minor_issues) > 0:
|
||||
# Full list of issues with potential fixes
|
||||
info_message = []
|
||||
for attribute_name, issue_description in minor_issues.items():
|
||||
info_message.append(f"- `{attribute_name}`: {issue_description}")
|
||||
info_message = "\n".join(info_message)
|
||||
info_message += (
|
||||
"\nIf you're using a pretrained model, note that some of these attributes may be set through the "
|
||||
"model's `generation_config.json` file."
|
||||
)
|
||||
|
||||
if strict:
|
||||
raise ValueError("GenerationConfig is invalid: \n" + info_message)
|
||||
else:
|
||||
attributes_with_issues = list(minor_issues.keys())
|
||||
warning_message = (
|
||||
f"The following generation flags are not valid and may be ignored: {attributes_with_issues}."
|
||||
)
|
||||
if logger.getEffectiveLevel() >= logging.WARNING:
|
||||
warning_message += " Set `TRANSFORMERS_VERBOSITY=info` for more details."
|
||||
logger.warning(warning_message)
|
||||
logger.info(info_message)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
@ -871,18 +861,13 @@ class GenerationConfig(PushToHubMixin):
|
||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
|
||||
# At save time, validate the instance -- if any warning/exception is thrown, we refuse to save the instance.
|
||||
# At save time, validate the instance enforcing strictness -- if any warning/exception would be thrown, we
|
||||
# refuse to save the instance.
|
||||
# This strictness is enforced to prevent bad configurations from being saved and re-used.
|
||||
try:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
self.validate()
|
||||
if len(caught_warnings) > 0:
|
||||
raise ValueError(str([w.message for w in caught_warnings]))
|
||||
self.validate(strict=True)
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
"The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. "
|
||||
"Fix these issues to save the configuration.\n\nThrown during validation:\n" + str(exc)
|
||||
)
|
||||
raise ValueError(str(exc) + "\n\nFix these issues to save the configuration.")
|
||||
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
@ -129,15 +128,10 @@ class Seq2SeqTrainer(Trainer):
|
||||
# Strict validation to fail early. `GenerationConfig.save_pretrained()`, run at the end of training, throws
|
||||
# an exception if there are warnings at validation time.
|
||||
try:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
gen_config.validate()
|
||||
if len(caught_warnings) > 0:
|
||||
raise ValueError(str([w.message for w in caught_warnings]))
|
||||
gen_config.validate(strict=True)
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
"The loaded generation config instance is invalid -- `GenerationConfig.validate()` throws warnings "
|
||||
"and/or exceptions. Fix these issues to train your model.\n\nThrown during validation:\n" + str(exc)
|
||||
)
|
||||
raise ValueError(str(exc) + "\n\nFix these issues to train your model.")
|
||||
|
||||
return gen_config
|
||||
|
||||
def evaluate(
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
@ -22,6 +23,7 @@ from huggingface_hub import HfFolder, create_pull_request
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoConfig, GenerationConfig, WatermarkingConfig, is_torch_available
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@ -55,7 +57,14 @@ from transformers.generation import (
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
WatermarkLogitsProcessor,
|
||||
)
|
||||
from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test, torch_device
|
||||
from transformers.testing_utils import (
|
||||
TOKEN,
|
||||
CaptureLogger,
|
||||
LoggingLevel,
|
||||
TemporaryHubRepo,
|
||||
is_staging_test,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
class GenerationConfigTest(unittest.TestCase):
|
||||
@ -112,24 +121,6 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
# `.update()` returns a dictionary of unused kwargs
|
||||
self.assertEqual(unused_kwargs, {"foo": "bar"})
|
||||
|
||||
# TODO: @Arthur and/or @Joao
|
||||
# FAILED tests/generation/test_configuration_utils.py::GenerationConfigTest::test_initialize_new_kwargs - AttributeError: 'GenerationConfig' object has no attribute 'get_text_config'
|
||||
# See: https://app.circleci.com/pipelines/github/huggingface/transformers/104831/workflows/e5e61514-51b7-4c8c-bba7-3c4d2986956e/jobs/1394252
|
||||
@unittest.skip("failed with `'GenerationConfig' object has no attribute 'get_text_config'`")
|
||||
def test_initialize_new_kwargs(self):
|
||||
generation_config = GenerationConfig()
|
||||
generation_config.foo = "bar"
|
||||
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
# update_kwargs was used to update the config on valid attributes
|
||||
self.assertEqual(new_config.foo, "bar")
|
||||
|
||||
generation_config = GenerationConfig.from_model_config(new_config)
|
||||
assert not hasattr(generation_config, "foo") # no new kwargs should be initialized if from config
|
||||
|
||||
def test_kwarg_init(self):
|
||||
"""Tests that we can overwrite attributes at `from_pretrained` time."""
|
||||
default_config = GenerationConfig()
|
||||
@ -159,38 +150,39 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
"""
|
||||
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
|
||||
"""
|
||||
logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
|
||||
|
||||
# A correct configuration will not throw any warning
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig()
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
|
||||
# Inconsequential but technically wrong configuration will throw a warning (e.g. setting sampling
|
||||
# parameters with `do_sample=False`). May be escalated to an error in the future.
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
GenerationConfig(do_sample=False, temperature=0.5)
|
||||
self.assertEqual(len(captured_warnings), 1)
|
||||
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig(return_dict_in_generate=False, output_scores=True)
|
||||
self.assertEqual(len(captured_warnings), 1)
|
||||
self.assertNotEqual(len(captured_logs.out), 0)
|
||||
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5) # store for later
|
||||
self.assertNotEqual(len(captured_logs.out), 0)
|
||||
|
||||
# Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
|
||||
# that is done by unsetting the parameter (i.e. setting it to None)
|
||||
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
# BAD - 0.9 means it is still set, we should warn
|
||||
generation_config_bad_temperature.update(temperature=0.9)
|
||||
self.assertEqual(len(captured_warnings), 1)
|
||||
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
self.assertNotEqual(len(captured_logs.out), 0)
|
||||
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
# CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn
|
||||
generation_config_bad_temperature.update(temperature=1.0)
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
# OK - None means it is unset, nothing to warn about
|
||||
generation_config_bad_temperature.update(temperature=None)
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
|
||||
# Impossible sets of constraints/parameters will raise an exception
|
||||
with self.assertRaises(ValueError):
|
||||
@ -206,9 +198,32 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
GenerationConfig(logits_processor="foo")
|
||||
|
||||
# Model-specific parameters will NOT raise an exception or a warning
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig(foo="bar")
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
|
||||
# By default we throw a short warning. However, we log with INFO level the details.
|
||||
# Default: we don't log the incorrect input values, only a short summary. We explain how to get more details.
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig(do_sample=False, temperature=0.5)
|
||||
self.assertNotIn("0.5", captured_logs.out)
|
||||
self.assertTrue(len(captured_logs.out) < 150) # short log
|
||||
self.assertIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out)
|
||||
|
||||
# INFO level: we share the full details
|
||||
with LoggingLevel(logging.INFO):
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig(do_sample=False, temperature=0.5)
|
||||
self.assertIn("0.5", captured_logs.out)
|
||||
self.assertTrue(len(captured_logs.out) > 400) # long log
|
||||
self.assertNotIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out)
|
||||
|
||||
# Finally, we can set `strict=True` to raise an exception on what would otherwise be a warning.
|
||||
generation_config = GenerationConfig()
|
||||
generation_config.temperature = 0.5
|
||||
generation_config.do_sample = False
|
||||
with self.assertRaises(ValueError):
|
||||
generation_config.validate(strict=True)
|
||||
|
||||
def test_refuse_to_save(self):
|
||||
"""Tests that we refuse to save a generation config that fails validation."""
|
||||
@ -221,6 +236,7 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
with self.assertRaises(ValueError) as exc:
|
||||
config.save_pretrained(tmp_dir)
|
||||
self.assertTrue("Fix these issues to save the configuration." in str(exc.exception))
|
||||
self.assertTrue("`temperature` is set to `0.5`" in str(exc.exception))
|
||||
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
|
||||
|
||||
# greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is
|
||||
@ -231,15 +247,24 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
with self.assertRaises(ValueError) as exc:
|
||||
config.save_pretrained(tmp_dir)
|
||||
self.assertTrue("Fix these issues to save the configuration." in str(exc.exception))
|
||||
self.assertTrue(
|
||||
"Greedy methods without beam search do not support `num_return_sequences` different than 1"
|
||||
in str(exc.exception)
|
||||
)
|
||||
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
|
||||
|
||||
# final check: no warnings/exceptions thrown if it is correct, and file is saved
|
||||
# Final check: no logs at warning level/warnings/exceptions thrown if it is correct, and file is saved.
|
||||
config = GenerationConfig()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Catch warnings
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
config.save_pretrained(tmp_dir)
|
||||
# Catch logs (up to WARNING level, the default level)
|
||||
logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
config.save_pretrained(tmp_dir)
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
self.assertTrue(len(os.listdir(tmp_dir)) == 1)
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
self.assertEqual(len(os.listdir(tmp_dir)), 1)
|
||||
|
||||
def test_generation_mode(self):
|
||||
"""Tests that the `get_generation_mode` method is working as expected."""
|
||||
|
@ -202,4 +202,4 @@ class Seq2seqTrainerTester(TestCasePlus):
|
||||
data_collator=data_collator,
|
||||
compute_metrics=lambda x: {"samples": x[0].shape[0]},
|
||||
)
|
||||
self.assertIn("The loaded generation config instance is invalid", str(exc.exception))
|
||||
self.assertIn("Fix these issues to train your model", str(exc.exception))
|
||||
|
Loading…
Reference in New Issue
Block a user