mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Generate: unset GenerationConfig parameters do not raise warning (#29119)
This commit is contained in:
parent
7d312ad2e9
commit
a7755d2409
@ -271,7 +271,6 @@ class GenerationConfig(PushToHubMixin):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Parameters that control the length of the output
|
||||
# if the default `max_length` is updated here, make sure to update the `generate` tests following https://github.com/huggingface/transformers/pull/25030
|
||||
self.max_length = kwargs.pop("max_length", 20)
|
||||
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
|
||||
self.min_length = kwargs.pop("min_length", 0)
|
||||
@ -407,32 +406,34 @@ class GenerationConfig(PushToHubMixin):
|
||||
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
|
||||
+ fix_location
|
||||
)
|
||||
if self.temperature != 1.0:
|
||||
if self.temperature is not None and self.temperature != 1.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
|
||||
UserWarning,
|
||||
)
|
||||
if self.top_p != 1.0:
|
||||
if self.top_p is not None and self.top_p != 1.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
|
||||
UserWarning,
|
||||
)
|
||||
if self.typical_p != 1.0:
|
||||
if self.typical_p is not None and self.typical_p != 1.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
|
||||
UserWarning,
|
||||
)
|
||||
if self.top_k != 50 and self.penalty_alpha is None: # contrastive search uses top_k
|
||||
if (
|
||||
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
|
||||
): # contrastive search uses top_k
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
|
||||
UserWarning,
|
||||
)
|
||||
if self.epsilon_cutoff != 0.0:
|
||||
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
|
||||
UserWarning,
|
||||
)
|
||||
if self.eta_cutoff != 0.0:
|
||||
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
|
||||
UserWarning,
|
||||
@ -453,21 +454,21 @@ class GenerationConfig(PushToHubMixin):
|
||||
single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
|
||||
UserWarning,
|
||||
)
|
||||
if self.num_beam_groups != 1:
|
||||
if self.num_beam_groups is not None and self.num_beam_groups != 1:
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(
|
||||
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
||||
),
|
||||
UserWarning,
|
||||
)
|
||||
if self.diversity_penalty != 0.0:
|
||||
if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(
|
||||
flag_name="diversity_penalty", flag_value=self.diversity_penalty
|
||||
),
|
||||
UserWarning,
|
||||
)
|
||||
if self.length_penalty != 1.0:
|
||||
if self.length_penalty is not None and self.length_penalty != 1.0:
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
|
||||
UserWarning,
|
||||
@ -491,7 +492,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
raise ValueError(
|
||||
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
|
||||
)
|
||||
if self.num_beam_groups != 1:
|
||||
if self.num_beam_groups is not None and self.num_beam_groups != 1:
|
||||
raise ValueError(
|
||||
constrained_wrong_parameter_msg.format(
|
||||
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
||||
@ -1000,6 +1001,9 @@ class GenerationConfig(PushToHubMixin):
|
||||
setattr(self, key, value)
|
||||
to_remove.append(key)
|
||||
|
||||
# remove all the attributes that were updated, without modifying the input dict
|
||||
# Confirm that the updated instance is still valid
|
||||
self.validate()
|
||||
|
||||
# Remove all the attributes that were updated, without modifying the input dict
|
||||
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
|
||||
return unused_kwargs
|
||||
|
@ -330,7 +330,6 @@ class FlaxGenerationMixin:
|
||||
|
||||
generation_config = copy.deepcopy(generation_config)
|
||||
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
|
||||
generation_config.validate()
|
||||
self._validate_model_kwargs(model_kwargs.copy())
|
||||
|
||||
logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList()
|
||||
|
@ -736,7 +736,6 @@ class TFGenerationMixin:
|
||||
|
||||
generation_config = copy.deepcopy(generation_config)
|
||||
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
|
||||
generation_config.validate()
|
||||
self._validate_model_kwargs(model_kwargs.copy())
|
||||
|
||||
# 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)
|
||||
|
@ -1347,7 +1347,6 @@ class GenerationMixin:
|
||||
|
||||
generation_config = copy.deepcopy(generation_config)
|
||||
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
|
||||
generation_config.validate()
|
||||
self._validate_model_kwargs(model_kwargs.copy())
|
||||
|
||||
# 2. Set generation parameters if not already defined
|
||||
|
@ -152,7 +152,6 @@ class QuantizationConfigMixin:
|
||||
config_dict = self.to_dict()
|
||||
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
# Copied from transformers.generation.configuration_utils.GenerationConfig.update
|
||||
def update(self, **kwargs):
|
||||
"""
|
||||
Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
|
||||
@ -171,7 +170,7 @@ class QuantizationConfigMixin:
|
||||
setattr(self, key, value)
|
||||
to_remove.append(key)
|
||||
|
||||
# remove all the attributes that were updated, without modifying the input dict
|
||||
# Remove all the attributes that were updated, without modifying the input dict
|
||||
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
|
||||
return unused_kwargs
|
||||
|
||||
|
@ -124,26 +124,44 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
"""
|
||||
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
|
||||
"""
|
||||
# Case 1: A correct configuration will not throw any warning
|
||||
# A correct configuration will not throw any warning
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
GenerationConfig()
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
|
||||
# Case 2: Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
|
||||
# Inconsequential but technically wrong configuration will throw a warning (e.g. setting sampling
|
||||
# parameters with `do_sample=False`). May be escalated to an error in the future.
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
GenerationConfig(temperature=0.5)
|
||||
GenerationConfig(do_sample=False, temperature=0.5)
|
||||
self.assertEqual(len(captured_warnings), 1)
|
||||
|
||||
# Case 3: Impossible sets of constraints/parameters will raise an exception
|
||||
with self.assertRaises(ValueError):
|
||||
GenerationConfig(num_return_sequences=2)
|
||||
# Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
|
||||
# that is done by unsetting the parameter (i.e. setting it to None)
|
||||
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
# BAD - 0.9 means it is still set, we should warn
|
||||
generation_config_bad_temperature.update(temperature=0.9)
|
||||
self.assertEqual(len(captured_warnings), 1)
|
||||
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
# CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn
|
||||
generation_config_bad_temperature.update(temperature=1.0)
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
# OK - None means it is unset, nothing to warn about
|
||||
generation_config_bad_temperature.update(temperature=None)
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
|
||||
# Case 4: Passing `generate()`-only flags to `validate` will raise an exception
|
||||
# Impossible sets of constraints/parameters will raise an exception
|
||||
with self.assertRaises(ValueError):
|
||||
GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=2)
|
||||
|
||||
# Passing `generate()`-only flags to `validate` will raise an exception
|
||||
with self.assertRaises(ValueError):
|
||||
GenerationConfig(logits_processor="foo")
|
||||
|
||||
# Case 5: Model-specific parameters will NOT raise an exception or a warning
|
||||
# Model-specific parameters will NOT raise an exception or a warning
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
GenerationConfig(foo="bar")
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
|
Loading…
Reference in New Issue
Block a user