[generate] add soft deprecations on custom generation methods (#38406)

soft deprecations
This commit is contained in:
Joao Gante 2025-06-02 11:11:46 +01:00 committed by GitHub
parent a75b9ffb5c
commit fe5bfaa4b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2342,8 +2342,8 @@ class GenerationMixin(ContinuousMixin):
- [`~generation.GenerateBeamEncoderDecoderOutput`] - [`~generation.GenerateBeamEncoderDecoderOutput`]
""" """
# 0. If requested, load an arbitrary generation recipe from the Hub and run it instead # 0. If requested, load an arbitrary generation recipe from the Hub and run it instead
if custom_generate is not None:
trust_remote_code = kwargs.pop("trust_remote_code", None) trust_remote_code = kwargs.pop("trust_remote_code", None)
if custom_generate is not None:
# Get all `generate` arguments in a single variable. Custom functions are responsible for handling them: # Get all `generate` arguments in a single variable. Custom functions are responsible for handling them:
# they receive the same inputs as `generate`, with `model` instead of `self` and excluding the arguments to # they receive the same inputs as `generate`, with `model` instead of `self` and excluding the arguments to
# trigger the custom generation. They can access to methods from `GenerationMixin` through `model`. # trigger the custom generation. They can access to methods from `GenerationMixin` through `model`.
@ -2564,6 +2564,11 @@ class GenerationMixin(ContinuousMixin):
**model_kwargs, **model_kwargs,
) )
elif generation_mode == GenerationMode.DOLA_GENERATION: elif generation_mode == GenerationMode.DOLA_GENERATION:
if not trust_remote_code:
logger.warning_once(
"DoLa Decoding is scheduled to be moved to a `custom_generate` repository in v4.55.0. "
"To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call."
)
if self._is_stateful: if self._is_stateful:
# DoLa decoding was not designed for stateful models, and would require some changes # DoLa decoding was not designed for stateful models, and would require some changes
raise ValueError( raise ValueError(
@ -2581,6 +2586,11 @@ class GenerationMixin(ContinuousMixin):
) )
elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH: elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
if not trust_remote_code:
logger.warning_once(
"Contrastive Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. "
"To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call."
)
if not model_kwargs["use_cache"]: if not model_kwargs["use_cache"]:
raise ValueError("Contrastive search requires `use_cache=True`") raise ValueError("Contrastive search requires `use_cache=True`")
if self._is_stateful: if self._is_stateful:
@ -2638,6 +2648,10 @@ class GenerationMixin(ContinuousMixin):
) )
elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH: elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
logger.warning_once(
"Group Beam Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. "
"To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call."
)
# 11. prepare beam search scorer # 11. prepare beam search scorer
beam_scorer = BeamSearchScorer( beam_scorer = BeamSearchScorer(
batch_size=batch_size, batch_size=batch_size,
@ -2668,6 +2682,10 @@ class GenerationMixin(ContinuousMixin):
) )
elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH: elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
logger.warning_once(
"Constrained Beam Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. "
"To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call."
)
final_constraints = [] final_constraints = []
if generation_config.constraints is not None: if generation_config.constraints is not None:
final_constraints = generation_config.constraints final_constraints = generation_config.constraints