Generate: unset GenerationConfig parameters do not raise warning (#29119)

This commit is contained in:
Joao Gante 2024-02-20 11:34:31 +00:00 committed by GitHub
parent 7d312ad2e9
commit a7755d2409
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 43 additions and 25 deletions

View File

@ -271,7 +271,6 @@ class GenerationConfig(PushToHubMixin):
def __init__(self, **kwargs):
# Parameters that control the length of the output
# if the default `max_length` is updated here, make sure to update the `generate` tests following https://github.com/huggingface/transformers/pull/25030
self.max_length = kwargs.pop("max_length", 20)
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
self.min_length = kwargs.pop("min_length", 0)
@ -407,32 +406,34 @@ class GenerationConfig(PushToHubMixin):
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
+ fix_location
)
if self.temperature != 1.0:
if self.temperature is not None and self.temperature != 1.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
UserWarning,
)
if self.top_p != 1.0:
if self.top_p is not None and self.top_p != 1.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
UserWarning,
)
if self.typical_p != 1.0:
if self.typical_p is not None and self.typical_p != 1.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
UserWarning,
)
if self.top_k != 50 and self.penalty_alpha is None: # contrastive search uses top_k
if (
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
): # contrastive search uses top_k
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
UserWarning,
)
if self.epsilon_cutoff != 0.0:
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
UserWarning,
)
if self.eta_cutoff != 0.0:
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
UserWarning,
@ -453,21 +454,21 @@ class GenerationConfig(PushToHubMixin):
single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
UserWarning,
)
if self.num_beam_groups != 1:
if self.num_beam_groups is not None and self.num_beam_groups != 1:
warnings.warn(
single_beam_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups
),
UserWarning,
)
if self.diversity_penalty != 0.0:
if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
warnings.warn(
single_beam_wrong_parameter_msg.format(
flag_name="diversity_penalty", flag_value=self.diversity_penalty
),
UserWarning,
)
if self.length_penalty != 1.0:
if self.length_penalty is not None and self.length_penalty != 1.0:
warnings.warn(
single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
UserWarning,
@ -491,7 +492,7 @@ class GenerationConfig(PushToHubMixin):
raise ValueError(
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
)
if self.num_beam_groups != 1:
if self.num_beam_groups is not None and self.num_beam_groups != 1:
raise ValueError(
constrained_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups
@ -1000,6 +1001,9 @@ class GenerationConfig(PushToHubMixin):
setattr(self, key, value)
to_remove.append(key)
# remove all the attributes that were updated, without modifying the input dict
# Confirm that the updated instance is still valid
self.validate()
# Remove all the attributes that were updated, without modifying the input dict
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
return unused_kwargs

View File

@ -330,7 +330,6 @@ class FlaxGenerationMixin:
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList()

View File

@ -736,7 +736,6 @@ class TFGenerationMixin:
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
# 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)

View File

@ -1347,7 +1347,6 @@ class GenerationMixin:
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
# 2. Set generation parameters if not already defined

View File

@ -152,7 +152,6 @@ class QuantizationConfigMixin:
config_dict = self.to_dict()
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
# Copied from transformers.generation.configuration_utils.GenerationConfig.update
def update(self, **kwargs):
"""
Updates attributes of this class instance with attributes from `kwargs` if they match existing atributtes,
@ -171,7 +170,7 @@ class QuantizationConfigMixin:
setattr(self, key, value)
to_remove.append(key)
# remove all the attributes that were updated, without modifying the input dict
# Remove all the attributes that were updated, without modifying the input dict
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
return unused_kwargs

View File

@ -124,26 +124,44 @@ class GenerationConfigTest(unittest.TestCase):
"""
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
"""
# Case 1: A correct configuration will not throw any warning
# A correct configuration will not throw any warning
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig()
self.assertEqual(len(captured_warnings), 0)
# Case 2: Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
# Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
# parameters with `do_sample=False`). May be escalated to an error in the future.
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(temperature=0.5)
GenerationConfig(do_sample=False, temperature=0.5)
self.assertEqual(len(captured_warnings), 1)
# Case 3: Impossible sets of contraints/parameters will raise an exception
with self.assertRaises(ValueError):
GenerationConfig(num_return_sequences=2)
# Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
# that is done by unsetting the parameter (i.e. setting it to None)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings:
# BAD - 0.9 means it is still set, we should warn
generation_config_bad_temperature.update(temperature=0.9)
self.assertEqual(len(captured_warnings), 1)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings:
# CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn
generation_config_bad_temperature.update(temperature=1.0)
self.assertEqual(len(captured_warnings), 0)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings:
# OK - None means it is unset, nothing to warn about
generation_config_bad_temperature.update(temperature=None)
self.assertEqual(len(captured_warnings), 0)
# Case 4: Passing `generate()`-only flags to `validate` will raise an exception
# Impossible sets of contraints/parameters will raise an exception
with self.assertRaises(ValueError):
GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=2)
# Passing `generate()`-only flags to `validate` will raise an exception
with self.assertRaises(ValueError):
GenerationConfig(logits_processor="foo")
# Case 5: Model-specific parameters will NOT raise an exception or a warning
# Model-specific parameters will NOT raise an exception or a warning
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(foo="bar")
self.assertEqual(len(captured_warnings), 0)