mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Generate: lower severity of parameterization checks (#25407)
This commit is contained in:
parent
ef74da6582
commit
eb3ded16f7
@ -313,7 +313,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
raise err
|
||||
|
||||
# Validate the values of the attributes
|
||||
self.validate()
|
||||
self.validate(is_init=True)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, GenerationConfig):
|
||||
@ -330,7 +330,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
def __repr__(self):
    """Return the instance's class name, a space, then the output of `self.to_json_string()`."""
    class_name = self.__class__.__name__
    serialized = self.to_json_string()
    return f"{class_name} {serialized}"
|
||||
|
||||
def validate(self):
|
||||
def validate(self, is_init=False):
|
||||
"""
|
||||
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
|
||||
of parameterization that can be detected as incorrect from the configuration instance alone.
|
||||
@ -344,58 +344,85 @@ class GenerationConfig(PushToHubMixin):
|
||||
raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
|
||||
|
||||
# Validation of attribute relations:
|
||||
fix_location = ""
|
||||
if is_init:
|
||||
fix_location = (
|
||||
" This was detected when initializing the generation config instance, which means the corresponding "
|
||||
"file may hold incorrect parameterization and should be fixed."
|
||||
)
|
||||
|
||||
# 1. detect sampling-only parameterization when not in sampling mode
|
||||
if self.do_sample is False:
|
||||
greedy_wrong_parameter_msg = (
|
||||
"`do_sample` is set to `False`. However, {flag_name} is set to {flag_value} -- this flag is only used "
|
||||
"in sample-based generation modes. Set `do_sample=True` or unset {flag_name} to continue."
|
||||
"in sample-based generation modes. You should set `do_sample=True` or unset {flag_name}."
|
||||
+ fix_location
|
||||
)
|
||||
if self.temperature != 1.0:
|
||||
raise ValueError(
|
||||
greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature)
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
|
||||
UserWarning,
|
||||
)
|
||||
if self.top_p != 1.0:
|
||||
raise ValueError(greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p))
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
|
||||
UserWarning,
|
||||
)
|
||||
if self.typical_p != 1.0:
|
||||
raise ValueError(greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p))
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
|
||||
UserWarning,
|
||||
)
|
||||
if self.top_k != 50 and self.penalty_alpha is None: # contrastive search uses top_k
|
||||
raise ValueError(greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k))
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
|
||||
UserWarning,
|
||||
)
|
||||
if self.epsilon_cutoff != 0.0:
|
||||
raise ValueError(
|
||||
greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff)
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
|
||||
UserWarning,
|
||||
)
|
||||
if self.eta_cutoff != 0.0:
|
||||
raise ValueError(greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff))
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
# 2. detect beam-only parameterization when not in beam mode
|
||||
if self.num_beams == 1:
|
||||
single_beam_wrong_parameter_msg = (
|
||||
"`num_beams` is set to 1. However, {flag_name} is set to {flag_value} -- this flag is only used in "
|
||||
"beam-based generation modes. Set `num_beams>1` or unset {flag_name} to continue."
|
||||
"beam-based generation modes. You should set `num_beams>1` or unset {flag_name}." + fix_location
|
||||
)
|
||||
if self.early_stopping is not False:
|
||||
raise ValueError(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping)
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
|
||||
UserWarning,
|
||||
)
|
||||
if self.num_beam_groups != 1:
|
||||
raise ValueError(
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(
|
||||
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
||||
)
|
||||
),
|
||||
UserWarning,
|
||||
)
|
||||
if self.diversity_penalty != 0.0:
|
||||
raise ValueError(
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(
|
||||
flag_name="diversity_penalty", flag_value=self.diversity_penalty
|
||||
)
|
||||
),
|
||||
UserWarning,
|
||||
)
|
||||
if self.length_penalty != 1.0:
|
||||
raise ValueError(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty)
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
|
||||
UserWarning,
|
||||
)
|
||||
if self.constraints is not None:
|
||||
raise ValueError(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints)
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints),
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
# 3. detect incorrect parameterization specific to advanced beam modes
|
||||
@ -405,7 +432,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
constrained_wrong_parameter_msg = (
|
||||
"`constraints` is not `None`, triggering constrained beam search. However, {flag_name} is set to "
|
||||
"{flag_value}, which is incompatible with this generation mode. Set `constraints=None` or unset "
|
||||
"{flag_name} to continue."
|
||||
"{flag_name} to continue." + fix_location
|
||||
)
|
||||
if self.do_sample is True:
|
||||
raise ValueError(
|
||||
|
Loading…
Reference in New Issue
Block a user