Generate: add config-level validation (#25381)

Joao Gante 2023-08-08 13:53:03 +01:00 committed by GitHub
parent 9e57e0c063
commit 5bd8c011bb
2 changed files with 111 additions and 61 deletions

src/transformers/generation/configuration_utils.py

@@ -332,12 +332,121 @@ class GenerationConfig(PushToHubMixin):
     def validate(self):
         """
-        Validates the values of the attributes of the GenerationConfig instance, and raises a `ValueError` if any of
-        the values are invalid.
+        Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
+        of parameterization that can be detected as incorrect from the configuration instance alone.
+
+        Note that some parameters are best validated at generate runtime, as they may depend on other inputs and/or the
+        model, such as parameters related to the generation length.
         """
+        # Validation of individual attributes
+        if self.early_stopping not in {True, False, "never"}:
+            raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
+
+        # Validation of attribute relations:
+        # 1. detect sampling-only parameterization when not in sampling mode
+        if self.do_sample is False:
+            greedy_wrong_parameter_msg = (
+                "`do_sample` is set to `False`. However, {flag_name} is set to {flag_value} -- this flag is only used "
+                "in sample-based generation modes. Set `do_sample=True` or unset {flag_name} to continue."
+            )
+            if self.temperature != 1.0:
+                raise ValueError(
+                    greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature)
+                )
+            if self.top_p != 1.0:
+                raise ValueError(greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p))
+            if self.typical_p != 1.0:
+                raise ValueError(greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p))
+            if self.top_k != 50 and self.penalty_alpha is None:  # contrastive search uses top_k
+                raise ValueError(greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k))
+            if self.epsilon_cutoff != 0.0:
+                raise ValueError(
+                    greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff)
+                )
+            if self.eta_cutoff != 0.0:
+                raise ValueError(greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff))
+
+        # 2. detect beam-only parameterization when not in beam mode
+        if self.num_beams == 1:
+            single_beam_wrong_parameter_msg = (
+                "`num_beams` is set to 1. However, {flag_name} is set to {flag_value} -- this flag is only used in "
+                "beam-based generation modes. Set `num_beams>1` or unset {flag_name} to continue."
+            )
+            if self.early_stopping is not False:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping)
+                )
+            if self.num_beam_groups != 1:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(
+                        flag_name="num_beam_groups", flag_value=self.num_beam_groups
+                    )
+                )
+            if self.diversity_penalty != 0.0:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(
+                        flag_name="diversity_penalty", flag_value=self.diversity_penalty
+                    )
+                )
+            if self.length_penalty != 1.0:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty)
+                )
+            if self.constraints is not None:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints)
+                )
+        # 3. detect incorrect parameterization specific to advanced beam modes
+        else:
+            # constrained beam search
+            if self.constraints is not None:
+                constrained_wrong_parameter_msg = (
+                    "`constraints` is not `None`, triggering constrained beam search. However, {flag_name} is set to "
+                    "{flag_value}, which is incompatible with this generation mode. Set `constraints=None` or unset "
+                    "{flag_name} to continue."
+                )
+                if self.do_sample is True:
+                    raise ValueError(
+                        constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
+                    )
+                if self.num_beam_groups != 1:
+                    raise ValueError(
+                        constrained_wrong_parameter_msg.format(
+                            flag_name="num_beam_groups", flag_value=self.num_beam_groups
+                        )
+                    )
+            # group beam search
+            if self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
+                group_error_prefix = (
+                    "`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
+                    "this generation mode, "
+                )
+                if self.do_sample is True:
+                    raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
+                if self.num_beams % self.num_beam_groups != 0:
+                    raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
+                if self.diversity_penalty == 0.0:
+                    raise ValueError(
+                        group_error_prefix
+                        + "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
+                    )
+
+        # 4. check `num_return_sequences`
+        if self.num_return_sequences != 1:
+            if self.num_beams == 1:
+                if self.do_sample is False:
+                    raise ValueError(
+                        "Greedy methods without beam search do not support `num_return_sequences` different than 1 "
+                        f"(got {self.num_return_sequences})."
+                    )
+            elif self.num_return_sequences > self.num_beams:
+                raise ValueError(
+                    f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
+                    f"({self.num_beams})."
+                )
 
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
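
For illustration only (not part of the diff): a minimal sketch of how the new config-level checks surface when `validate()` is called directly on a hand-built `GenerationConfig`. The flag values below are invented to trip checks 1, 2 and 4 above.

    from transformers import GenerationConfig

    # Check 1: temperature is a sampling-only flag, so it is rejected when do_sample=False
    # (illustrative values, not taken from the commit)
    try:
        GenerationConfig(do_sample=False, temperature=0.7).validate()
    except ValueError as err:
        print(err)

    # Check 2: length_penalty is a beam-only flag, so it is rejected when num_beams=1
    try:
        GenerationConfig(num_beams=1, length_penalty=2.0).validate()
    except ValueError as err:
        print(err)

    # Check 4: greedy decoding (no beams, no sampling) cannot return more than one sequence
    try:
        GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=3).validate()
    except ValueError as err:
        print(err)

The point of holding these checks on the config is that a bad parameterization can be rejected from the configuration instance alone, before a model or an actual `generate()` call is involved.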

src/transformers/generation/utils.py

@@ -1493,13 +1493,6 @@ class GenerationMixin:
         # 7. determine generation mode
         generation_mode = self._get_generation_mode(generation_config, assistant_model)
 
-        if generation_config.num_beam_groups > generation_config.num_beams:
-            raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
-        if generation_mode == GenerationMode.GROUP_BEAM_SEARCH and generation_config.do_sample is True:
-            raise ValueError(
-                "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
-            )
-
         if streamer is not None and (generation_config.num_beams > 1):
             raise ValueError(
                 "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
@@ -1572,12 +1565,6 @@
                 **model_kwargs,
             )
         if generation_mode == GenerationMode.GREEDY_SEARCH:
-            if generation_config.num_return_sequences > 1:
-                raise ValueError(
-                    "num_return_sequences has to be 1 when doing greedy search, "
-                    f"but is {generation_config.num_return_sequences}."
-                )
-
             # 11. run greedy search
             return self.greedy_search(
                 input_ids,
@@ -1593,11 +1580,6 @@
             )
         elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
-            if generation_config.num_return_sequences > 1:
-                raise ValueError(
-                    "num_return_sequences has to be 1 when doing contrastive search, "
-                    f"but is {generation_config.num_return_sequences}."
-                )
             if not model_kwargs["use_cache"]:
                 raise ValueError("Contrastive search requires `use_cache=True`")
@@ -1645,12 +1627,6 @@
             )
         elif generation_mode == GenerationMode.BEAM_SEARCH:
-            if generation_config.num_return_sequences > generation_config.num_beams:
-                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
-
-            if stopping_criteria.max_length is None:
-                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
-
             # 11. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
@@ -1686,8 +1662,6 @@
             # 11. prepare logits warper
             logits_warper = self._get_logits_warper(generation_config)
 
-            if stopping_criteria.max_length is None:
-                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
             # 12. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size * generation_config.num_return_sequences,
@@ -1722,24 +1696,6 @@
             )
         elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
-            if generation_config.num_return_sequences > generation_config.num_beams:
-                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
-
-            if generation_config.num_beams % generation_config.num_beam_groups != 0:
-                raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
-
-            if generation_config.diversity_penalty == 0.0:
-                raise ValueError(
-                    "`diversity_penalty` should be greater than `0.0`, otherwise your beam groups will be identical."
-                )
-
-            if stopping_criteria.max_length is None:
-                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
-
-            has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
-            if not has_default_typical_p:
-                raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")
-
             # 11. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
@@ -1773,21 +1729,6 @@
             )
         elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
-            if generation_config.num_return_sequences > generation_config.num_beams:
-                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
-
-            if stopping_criteria.max_length is None:
-                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
-
-            if generation_config.num_beams <= 1:
-                raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")
-
-            if generation_config.do_sample:
-                raise ValueError("`do_sample` needs to be false for constrained generation.")
-
-            if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1:
-                raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
-
             final_constraints = []
             if generation_config.constraints is not None:
                 final_constraints = generation_config.constraints
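
The deleted branches above lean on the new config-level method: misconfigurations like the ones they used to reject are now caught by `GenerationConfig.validate()` without a model or a `generate()` call. A minimal sketch of the group-beam-search cases (values invented for the example, not part of the diff; `validate()` called explicitly):

    from transformers import GenerationConfig

    # Group beam search (num_beam_groups > 1) is incompatible with sampling -- check 3 in validate()
    # (illustrative values, not taken from the commit)
    try:
        GenerationConfig(num_beams=4, num_beam_groups=2, diversity_penalty=0.5, do_sample=True).validate()
    except ValueError as err:
        print(err)

    # num_beams must be divisible by num_beam_groups
    try:
        GenerationConfig(num_beams=3, num_beam_groups=2, diversity_penalty=0.5).validate()
    except ValueError as err:
        print(err)

Checks that depend on runtime objects, such as the `stopping_criteria.max_length` requirements removed here, cannot be expressed on the config alone; as the new docstring notes, length-related parameters are best validated at generate runtime.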