Generate: lower severity of parameterization checks (#25407)

This commit is contained in:
Joao Gante 2023-08-09 13:15:06 +01:00 committed by GitHub
parent ef74da6582
commit eb3ded16f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -313,7 +313,7 @@ class GenerationConfig(PushToHubMixin):
raise err
# Validate the values of the attributes
self.validate()
self.validate(is_init=True)
def __eq__(self, other):
if not isinstance(other, GenerationConfig):
@ -330,7 +330,7 @@ class GenerationConfig(PushToHubMixin):
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
def validate(self):
def validate(self, is_init=False):
"""
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
of parameterization that can be detected as incorrect from the configuration instance alone.
@ -344,58 +344,85 @@ class GenerationConfig(PushToHubMixin):
raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
# Validation of attribute relations:
fix_location = ""
if is_init:
fix_location = (
" This was detected when initializing the generation config instance, which means the corresponding "
"file may hold incorrect parameterization and should be fixed."
)
# 1. detect sampling-only parameterization when not in sampling mode
if self.do_sample is False:
greedy_wrong_parameter_msg = (
"`do_sample` is set to `False`. However, {flag_name} is set to {flag_value} -- this flag is only used "
"in sample-based generation modes. Set `do_sample=True` or unset {flag_name} to continue."
"in sample-based generation modes. You should set `do_sample=True` or unset {flag_name}."
+ fix_location
)
if self.temperature != 1.0:
raise ValueError(
greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature)
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
UserWarning,
)
if self.top_p != 1.0:
raise ValueError(greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p))
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
UserWarning,
)
if self.typical_p != 1.0:
raise ValueError(greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p))
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
UserWarning,
)
if self.top_k != 50 and self.penalty_alpha is None: # contrastive search uses top_k
raise ValueError(greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k))
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
UserWarning,
)
if self.epsilon_cutoff != 0.0:
raise ValueError(
greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff)
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
UserWarning,
)
if self.eta_cutoff != 0.0:
raise ValueError(greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff))
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
UserWarning,
)
# 2. detect beam-only parameterization when not in beam mode
if self.num_beams == 1:
single_beam_wrong_parameter_msg = (
"`num_beams` is set to 1. However, {flag_name} is set to {flag_value} -- this flag is only used in "
"beam-based generation modes. Set `num_beams>1` or unset {flag_name} to continue."
"beam-based generation modes. You should set `num_beams>1` or unset {flag_name}." + fix_location
)
if self.early_stopping is not False:
raise ValueError(
single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping)
warnings.warn(
single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
UserWarning,
)
if self.num_beam_groups != 1:
raise ValueError(
warnings.warn(
single_beam_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups
)
),
UserWarning,
)
if self.diversity_penalty != 0.0:
raise ValueError(
warnings.warn(
single_beam_wrong_parameter_msg.format(
flag_name="diversity_penalty", flag_value=self.diversity_penalty
)
),
UserWarning,
)
if self.length_penalty != 1.0:
raise ValueError(
single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty)
warnings.warn(
single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
UserWarning,
)
if self.constraints is not None:
raise ValueError(
single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints)
warnings.warn(
single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints),
UserWarning,
)
# 3. detect incorrect paramaterization specific to advanced beam modes
@ -405,7 +432,7 @@ class GenerationConfig(PushToHubMixin):
constrained_wrong_parameter_msg = (
"`constraints` is not `None`, triggering constrained beam search. However, {flag_name} is set to "
"{flag_value}, which is incompatible with this generation mode. Set `constraints=None` or unset "
"{flag_name} to continue."
"{flag_name} to continue." + fix_location
)
if self.do_sample is True:
raise ValueError(