Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
Generate: add config-level validation (#25381)
This commit is contained in:
parent 9e57e0c063
commit 5bd8c011bb
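This commit moves generation-flag sanity checks from `generate()` runtime into a new `GenerationConfig.validate()` method, so inconsistent parameterization is caught at the config level. A minimal sketch of the resulting behavior (assuming a transformers version that includes this commit):

from transformers import GenerationConfig

# Sampling-only flags are now rejected when `do_sample=False` (rule 1 in the hunk below).
config = GenerationConfig(do_sample=False, temperature=0.7)
try:
    config.validate()
except ValueError as err:
    print(err)  # "`do_sample` is set to `False`. However, temperature is set to 0.7 ..."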
@@ -332,12 +332,121 @@ class GenerationConfig(PushToHubMixin):
     def validate(self):
         """
-        Validates the values of the attributes of the GenerationConfig instance, and raises a `ValueError` if any of
-        the values are invalid.
+        Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
+        of parameterization that can be detected as incorrect from the configuration instance alone.
+
+        Note that some parameters are best validated at generate runtime, as they may depend on other inputs and/or the
+        model, such as parameters related to the generation length.
         """
+        # Validation of individual attributes
         if self.early_stopping not in {True, False, "never"}:
             raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
 
+        # Validation of attribute relations:
+        # 1. detect sampling-only parameterization when not in sampling mode
+        if self.do_sample is False:
+            greedy_wrong_parameter_msg = (
+                "`do_sample` is set to `False`. However, {flag_name} is set to {flag_value} -- this flag is only used "
+                "in sample-based generation modes. Set `do_sample=True` or unset {flag_name} to continue."
+            )
+            if self.temperature != 1.0:
+                raise ValueError(
+                    greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature)
+                )
+            if self.top_p != 1.0:
+                raise ValueError(greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p))
+            if self.typical_p != 1.0:
+                raise ValueError(greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p))
+            if self.top_k != 50 and self.penalty_alpha is None:  # contrastive search uses top_k
+                raise ValueError(greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k))
+            if self.epsilon_cutoff != 0.0:
+                raise ValueError(
+                    greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff)
+                )
+            if self.eta_cutoff != 0.0:
+                raise ValueError(greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff))
+
+        # 2. detect beam-only parameterization when not in beam mode
+        if self.num_beams == 1:
+            single_beam_wrong_parameter_msg = (
+                "`num_beams` is set to 1. However, {flag_name} is set to {flag_value} -- this flag is only used in "
+                "beam-based generation modes. Set `num_beams>1` or unset {flag_name} to continue."
+            )
+            if self.early_stopping is not False:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping)
+                )
+            if self.num_beam_groups != 1:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(
+                        flag_name="num_beam_groups", flag_value=self.num_beam_groups
+                    )
+                )
+            if self.diversity_penalty != 0.0:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(
+                        flag_name="diversity_penalty", flag_value=self.diversity_penalty
+                    )
+                )
+            if self.length_penalty != 1.0:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty)
+                )
+            if self.constraints is not None:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints)
+                )
+
+        # 3. detect incorrect parameterization specific to advanced beam modes
+        else:
+            # constrained beam search
+            if self.constraints is not None:
+                constrained_wrong_parameter_msg = (
+                    "`constraints` is not `None`, triggering constrained beam search. However, {flag_name} is set to "
+                    "{flag_value}, which is incompatible with this generation mode. Set `constraints=None` or unset "
+                    "{flag_name} to continue."
+                )
+                if self.do_sample is True:
+                    raise ValueError(
+                        constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
+                    )
+                if self.num_beam_groups != 1:
+                    raise ValueError(
+                        constrained_wrong_parameter_msg.format(
+                            flag_name="num_beam_groups", flag_value=self.num_beam_groups
+                        )
+                    )
+            # group beam search
+            if self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
+                group_error_prefix = (
+                    "`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
+                    "this generation mode, "
+                )
+                if self.do_sample is True:
+                    raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
+                if self.num_beams % self.num_beam_groups != 0:
+                    raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
+                if self.diversity_penalty == 0.0:
+                    raise ValueError(
+                        group_error_prefix
+                        + "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
+                    )
+
+        # 4. check `num_return_sequences`
+        if self.num_return_sequences != 1:
+            if self.num_beams == 1:
+                if self.do_sample is False:
+                    raise ValueError(
+                        "Greedy methods without beam search do not support `num_return_sequences` different than 1 "
+                        f"(got {self.num_return_sequences})."
+                    )
+            elif self.num_return_sequences > self.num_beams:
+                raise ValueError(
+                    f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
+                    f"({self.num_beams})."
+                )
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
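The remaining hunks delete the per-mode runtime checks from `GenerationMixin.generate` that the rules above make redundant. For instance, beam-only flags now fail fast under rule 2 when `num_beams` is left at its default of 1 (again a sketch, under the same version assumption):

from transformers import GenerationConfig

config = GenerationConfig(num_beams=1, length_penalty=0.5)
try:
    config.validate()
except ValueError as err:
    print(err)  # "`num_beams` is set to 1. However, length_penalty is set to 0.5 ..."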
@@ -1493,13 +1493,6 @@ class GenerationMixin:
         # 7. determine generation mode
         generation_mode = self._get_generation_mode(generation_config, assistant_model)
 
-        if generation_config.num_beam_groups > generation_config.num_beams:
-            raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
-        if generation_mode == GenerationMode.GROUP_BEAM_SEARCH and generation_config.do_sample is True:
-            raise ValueError(
-                "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
-            )
-
         if streamer is not None and (generation_config.num_beams > 1):
             raise ValueError(
                 "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
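The diverse-beam-search guards removed above map onto rule 3 of `validate()`: with `num_beams > 1` and `num_beam_groups != 1`, sampling is rejected at the config level (sketch, same assumption):

from transformers import GenerationConfig

config = GenerationConfig(num_beams=4, num_beam_groups=2, do_sample=True)
try:
    config.validate()
except ValueError as err:
    print(err)  # "... triggering group beam search. In this generation mode, `do_sample` must be set to `False`"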
@@ -1572,12 +1565,6 @@ class GenerationMixin:
             **model_kwargs,
         )
         if generation_mode == GenerationMode.GREEDY_SEARCH:
-            if generation_config.num_return_sequences > 1:
-                raise ValueError(
-                    "num_return_sequences has to be 1 when doing greedy search, "
-                    f"but is {generation_config.num_return_sequences}."
-                )
-
             # 11. run greedy search
             return self.greedy_search(
                 input_ids,
@@ -1593,11 +1580,6 @@ class GenerationMixin:
             )
 
         elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
-            if generation_config.num_return_sequences > 1:
-                raise ValueError(
-                    "num_return_sequences has to be 1 when doing contrastive search, "
-                    f"but is {generation_config.num_return_sequences}."
-                )
             if not model_kwargs["use_cache"]:
                 raise ValueError("Contrastive search requires `use_cache=True`")
 
@@ -1645,12 +1627,6 @@ class GenerationMixin:
             )
 
         elif generation_mode == GenerationMode.BEAM_SEARCH:
-            if generation_config.num_return_sequences > generation_config.num_beams:
-                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
-
-            if stopping_criteria.max_length is None:
-                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
-
             # 11. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
@@ -1686,8 +1662,6 @@ class GenerationMixin:
             # 11. prepare logits warper
             logits_warper = self._get_logits_warper(generation_config)
 
-            if stopping_criteria.max_length is None:
-                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
             # 12. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size * generation_config.num_return_sequences,
@@ -1722,24 +1696,6 @@ class GenerationMixin:
             )
 
         elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
-            if generation_config.num_return_sequences > generation_config.num_beams:
-                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
-
-            if generation_config.num_beams % generation_config.num_beam_groups != 0:
-                raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
-
-            if generation_config.diversity_penalty == 0.0:
-                raise ValueError(
-                    "`diversity_penalty` should be greater than `0.0`, otherwise your beam groups will be identical."
-                )
-
-            if stopping_criteria.max_length is None:
-                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
-
-            has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
-            if not has_default_typical_p:
-                raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")
-
             # 11. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
@@ -1773,21 +1729,6 @@ class GenerationMixin:
             )
 
         elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
-            if generation_config.num_return_sequences > generation_config.num_beams:
-                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
-
-            if stopping_criteria.max_length is None:
-                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
-
-            if generation_config.num_beams <= 1:
-                raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")
-
-            if generation_config.do_sample:
-                raise ValueError("`do_sample` needs to be false for constrained generation.")
-
-            if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1:
-                raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
-
             final_constraints = []
             if generation_config.constraints is not None:
                 final_constraints = generation_config.constraints
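Likewise, the `num_return_sequences` guards deleted from the greedy, contrastive, beam, and constrained branches above are covered by rule 4 of `validate()` (sketch, same assumption):

from transformers import GenerationConfig

config = GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=3)
try:
    config.validate()
except ValueError as err:
    print(err)  # "Greedy methods without beam search do not support `num_return_sequences` different than 1 ..."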