mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Generate: add config-level validation (#25381)
This commit is contained in:
parent
9e57e0c063
commit
5bd8c011bb
@ -332,12 +332,121 @@ class GenerationConfig(PushToHubMixin):
|
||||
|
||||
def validate(self):
|
||||
"""
|
||||
Validates the values of the attributes of the GenerationConfig instance, and raises a `ValueError` if any of
|
||||
the values are invalid.
|
||||
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
|
||||
of parameterization that can be detected as incorrect from the configuration instance alone.
|
||||
|
||||
Note that some parameters are best validated at generate runtime, as they may depend on other inputs and/or the
|
||||
model, such as parameters related to the generation length.
|
||||
"""
|
||||
|
||||
# Validation of individual attributes
|
||||
if self.early_stopping not in {True, False, "never"}:
|
||||
raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
|
||||
|
||||
# Validation of attribute relations:
|
||||
# 1. detect sampling-only parameterization when not in sampling mode
|
||||
if self.do_sample is False:
|
||||
greedy_wrong_parameter_msg = (
|
||||
"`do_sample` is set to `False`. However, {flag_name} is set to {flag_value} -- this flag is only used "
|
||||
"in sample-based generation modes. Set `do_sample=True` or unset {flag_name} to continue."
|
||||
)
|
||||
if self.temperature != 1.0:
|
||||
raise ValueError(
|
||||
greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature)
|
||||
)
|
||||
if self.top_p != 1.0:
|
||||
raise ValueError(greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p))
|
||||
if self.typical_p != 1.0:
|
||||
raise ValueError(greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p))
|
||||
if self.top_k != 50 and self.penalty_alpha is None: # contrastive search uses top_k
|
||||
raise ValueError(greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k))
|
||||
if self.epsilon_cutoff != 0.0:
|
||||
raise ValueError(
|
||||
greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff)
|
||||
)
|
||||
if self.eta_cutoff != 0.0:
|
||||
raise ValueError(greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff))
|
||||
|
||||
# 2. detect beam-only parameterization when not in beam mode
|
||||
if self.num_beams == 1:
|
||||
single_beam_wrong_parameter_msg = (
|
||||
"`num_beams` is set to 1. However, {flag_name} is set to {flag_value} -- this flag is only used in "
|
||||
"beam-based generation modes. Set `num_beams>1` or unset {flag_name} to continue."
|
||||
)
|
||||
if self.early_stopping is not False:
|
||||
raise ValueError(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping)
|
||||
)
|
||||
if self.num_beam_groups != 1:
|
||||
raise ValueError(
|
||||
single_beam_wrong_parameter_msg.format(
|
||||
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
||||
)
|
||||
)
|
||||
if self.diversity_penalty != 0.0:
|
||||
raise ValueError(
|
||||
single_beam_wrong_parameter_msg.format(
|
||||
flag_name="diversity_penalty", flag_value=self.diversity_penalty
|
||||
)
|
||||
)
|
||||
if self.length_penalty != 1.0:
|
||||
raise ValueError(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty)
|
||||
)
|
||||
if self.constraints is not None:
|
||||
raise ValueError(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints)
|
||||
)
|
||||
|
||||
# 3. detect incorrect paramaterization specific to advanced beam modes
|
||||
else:
|
||||
# constrained beam search
|
||||
if self.constraints is not None:
|
||||
constrained_wrong_parameter_msg = (
|
||||
"`constraints` is not `None`, triggering constrained beam search. However, {flag_name} is set to "
|
||||
"{flag_value}, which is incompatible with this generation mode. Set `constraints=None` or unset "
|
||||
"{flag_name} to continue."
|
||||
)
|
||||
if self.do_sample is True:
|
||||
raise ValueError(
|
||||
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
|
||||
)
|
||||
if self.num_beam_groups != 1:
|
||||
raise ValueError(
|
||||
constrained_wrong_parameter_msg.format(
|
||||
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
||||
)
|
||||
)
|
||||
# group beam search
|
||||
if self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
|
||||
group_error_prefix = (
|
||||
"`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
|
||||
"this generation mode, "
|
||||
)
|
||||
if self.do_sample is True:
|
||||
raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
|
||||
if self.num_beams % self.num_beam_groups != 0:
|
||||
raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
|
||||
if self.diversity_penalty == 0.0:
|
||||
raise ValueError(
|
||||
group_error_prefix
|
||||
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
|
||||
)
|
||||
|
||||
# 4. check `num_return_sequences`
|
||||
if self.num_return_sequences != 1:
|
||||
if self.num_beams == 1:
|
||||
if self.do_sample is False:
|
||||
raise ValueError(
|
||||
"Greedy methods without beam search do not support `num_return_sequences` different than 1 "
|
||||
f"(got {self.num_return_sequences})."
|
||||
)
|
||||
elif self.num_return_sequences > self.num_beams:
|
||||
raise ValueError(
|
||||
f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
|
||||
f"({self.num_beams})."
|
||||
)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
|
@ -1493,13 +1493,6 @@ class GenerationMixin:
|
||||
# 7. determine generation mode
|
||||
generation_mode = self._get_generation_mode(generation_config, assistant_model)
|
||||
|
||||
if generation_config.num_beam_groups > generation_config.num_beams:
|
||||
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
|
||||
if generation_mode == GenerationMode.GROUP_BEAM_SEARCH and generation_config.do_sample is True:
|
||||
raise ValueError(
|
||||
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
|
||||
)
|
||||
|
||||
if streamer is not None and (generation_config.num_beams > 1):
|
||||
raise ValueError(
|
||||
"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
|
||||
@ -1572,12 +1565,6 @@ class GenerationMixin:
|
||||
**model_kwargs,
|
||||
)
|
||||
if generation_mode == GenerationMode.GREEDY_SEARCH:
|
||||
if generation_config.num_return_sequences > 1:
|
||||
raise ValueError(
|
||||
"num_return_sequences has to be 1 when doing greedy search, "
|
||||
f"but is {generation_config.num_return_sequences}."
|
||||
)
|
||||
|
||||
# 11. run greedy search
|
||||
return self.greedy_search(
|
||||
input_ids,
|
||||
@ -1593,11 +1580,6 @@ class GenerationMixin:
|
||||
)
|
||||
|
||||
elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
|
||||
if generation_config.num_return_sequences > 1:
|
||||
raise ValueError(
|
||||
"num_return_sequences has to be 1 when doing contrastive search, "
|
||||
f"but is {generation_config.num_return_sequences}."
|
||||
)
|
||||
if not model_kwargs["use_cache"]:
|
||||
raise ValueError("Contrastive search requires `use_cache=True`")
|
||||
|
||||
@ -1645,12 +1627,6 @@ class GenerationMixin:
|
||||
)
|
||||
|
||||
elif generation_mode == GenerationMode.BEAM_SEARCH:
|
||||
if generation_config.num_return_sequences > generation_config.num_beams:
|
||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||
|
||||
if stopping_criteria.max_length is None:
|
||||
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
|
||||
|
||||
# 11. prepare beam search scorer
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
@ -1686,8 +1662,6 @@ class GenerationMixin:
|
||||
# 11. prepare logits warper
|
||||
logits_warper = self._get_logits_warper(generation_config)
|
||||
|
||||
if stopping_criteria.max_length is None:
|
||||
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
|
||||
# 12. prepare beam search scorer
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size * generation_config.num_return_sequences,
|
||||
@ -1722,24 +1696,6 @@ class GenerationMixin:
|
||||
)
|
||||
|
||||
elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
|
||||
if generation_config.num_return_sequences > generation_config.num_beams:
|
||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||
|
||||
if generation_config.num_beams % generation_config.num_beam_groups != 0:
|
||||
raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
|
||||
|
||||
if generation_config.diversity_penalty == 0.0:
|
||||
raise ValueError(
|
||||
"`diversity_penalty` should be greater than `0.0`, otherwise your beam groups will be identical."
|
||||
)
|
||||
|
||||
if stopping_criteria.max_length is None:
|
||||
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
|
||||
|
||||
has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
|
||||
if not has_default_typical_p:
|
||||
raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")
|
||||
|
||||
# 11. prepare beam search scorer
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
@ -1773,21 +1729,6 @@ class GenerationMixin:
|
||||
)
|
||||
|
||||
elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
|
||||
if generation_config.num_return_sequences > generation_config.num_beams:
|
||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||
|
||||
if stopping_criteria.max_length is None:
|
||||
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
|
||||
|
||||
if generation_config.num_beams <= 1:
|
||||
raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")
|
||||
|
||||
if generation_config.do_sample:
|
||||
raise ValueError("`do_sample` needs to be false for constrained generation.")
|
||||
|
||||
if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1:
|
||||
raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
|
||||
|
||||
final_constraints = []
|
||||
if generation_config.constraints is not None:
|
||||
final_constraints = generation_config.constraints
|
||||
|
Loading…
Reference in New Issue
Block a user