mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[generate] model defaults being inherited only happens for newer models (#36881)
This commit is contained in:
parent
f19d018bff
commit
94f487626a
@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
@ -1552,7 +1553,7 @@ class GenerationMixin:
|
||||
return generation_config
|
||||
|
||||
def _prepare_generation_config(
|
||||
self, generation_config: Optional[GenerationConfig], **kwargs: Dict
|
||||
self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Dict
|
||||
) -> Tuple[GenerationConfig, Dict]:
|
||||
"""
|
||||
Prepares the base generation config, then applies any generation configuration options from kwargs. This
|
||||
@ -1591,23 +1592,38 @@ class GenerationMixin:
|
||||
|
||||
generation_config = copy.deepcopy(generation_config)
|
||||
|
||||
# If `generation_config` is provided, let's fallback ALL default values to the model's generation config
|
||||
if not using_model_generation_config:
|
||||
modified_values = {}
|
||||
default_generation_config = GenerationConfig()
|
||||
for key, default_value in default_generation_config.__dict__.items():
|
||||
if key.startswith("_"): # metadata
|
||||
continue
|
||||
custom_gen_config_value = getattr(generation_config, key)
|
||||
model_gen_config_value = getattr(self.generation_config, key)
|
||||
if custom_gen_config_value == default_value and model_gen_config_value != default_value:
|
||||
modified_values[key] = model_gen_config_value
|
||||
setattr(generation_config, key, model_gen_config_value)
|
||||
if len(modified_values) > 0:
|
||||
logger.warning_once(
|
||||
f"`generation_config` default values have been modified to match model-specific defaults: "
|
||||
f"{modified_values}. If this is not desired, please set these values explicitly."
|
||||
)
|
||||
# If `generation_config` is provided:
|
||||
# - `use_model_defaults`: let's fallback ALL default values to the model's generation config
|
||||
# - otherwise: legacy behavior, let's just make sure we have the tokens defined
|
||||
model_base_version = version.parse(version.parse(self.generation_config.transformers_version).base_version)
|
||||
if use_model_defaults is True or (
|
||||
use_model_defaults is None and model_base_version >= version.parse("4.50.0")
|
||||
):
|
||||
modified_values = {}
|
||||
default_generation_config = GenerationConfig()
|
||||
for key, default_value in default_generation_config.__dict__.items():
|
||||
if key.startswith("_") or key == "transformers_version": # metadata
|
||||
continue
|
||||
custom_gen_config_value = getattr(generation_config, key)
|
||||
model_gen_config_value = getattr(self.generation_config, key)
|
||||
if custom_gen_config_value == default_value and model_gen_config_value != default_value:
|
||||
modified_values[key] = model_gen_config_value
|
||||
setattr(generation_config, key, model_gen_config_value)
|
||||
if len(modified_values) > 0:
|
||||
logger.warning_once(
|
||||
f"`generation_config` default values have been modified to match model-specific defaults: "
|
||||
f"{modified_values}. If this is not desired, please set these values explicitly."
|
||||
)
|
||||
else:
|
||||
if generation_config.bos_token_id is None:
|
||||
generation_config.bos_token_id = self.generation_config.bos_token_id
|
||||
if generation_config.eos_token_id is None:
|
||||
generation_config.eos_token_id = self.generation_config.eos_token_id
|
||||
if generation_config.pad_token_id is None:
|
||||
generation_config.pad_token_id = self.generation_config.pad_token_id
|
||||
if generation_config.decoder_start_token_id is None:
|
||||
generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
|
||||
|
||||
# Finally, apply any passed kwargs
|
||||
model_kwargs = generation_config.update(**kwargs)
|
||||
@ -1967,6 +1983,7 @@ class GenerationMixin:
|
||||
streamer: Optional["BaseStreamer"] = None,
|
||||
negative_prompt_ids: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
use_model_defaults: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||
r"""
|
||||
@ -2031,6 +2048,11 @@ class GenerationMixin:
|
||||
size. This is an experimental feature, subject to breaking API changes in future versions.
|
||||
negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Attention_mask for `negative_prompt_ids`.
|
||||
use_model_defaults (`bool`, *optional*):
|
||||
When it is `True`, unset parameters in `generation_config` will be set to the model-specific default
|
||||
generation configuration (`model.generation_config`), as opposed to the global defaults
|
||||
(`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
|
||||
`True`.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
|
||||
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
||||
@ -2058,7 +2080,9 @@ class GenerationMixin:
|
||||
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
|
||||
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation
|
||||
|
||||
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
|
||||
generation_config, model_kwargs = self._prepare_generation_config(
|
||||
generation_config, use_model_defaults, **kwargs
|
||||
)
|
||||
self._validate_model_kwargs(model_kwargs.copy())
|
||||
self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
|
||||
|
||||
|
@ -575,8 +575,8 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
||||
|
||||
def test_generation_beyond_sliding_window_with_generation_config(self):
|
||||
"""
|
||||
Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
|
||||
ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
|
||||
Similar to `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684
|
||||
-- ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
|
||||
"""
|
||||
model_id = "google/gemma-3-1b-it"
|
||||
attn_implementation = "sdpa"
|
||||
@ -594,12 +594,16 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
||||
|
||||
# Make sure prefill is larger than sliding window
|
||||
input_size = inputs.input_ids.shape[-1]
|
||||
self.assertTrue(input_size > model.config.sliding_window)
|
||||
self.assertGreater(input_size, model.config.sliding_window)
|
||||
|
||||
generation_config = GenerationConfig(max_new_tokens=20)
|
||||
generation_config = GenerationConfig(max_new_tokens=5, min_new_tokens=5)
|
||||
out = model.generate(**inputs, generation_config=generation_config)
|
||||
|
||||
out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
|
||||
output_text = tokenizer.batch_decode(out)
|
||||
# Generation works beyond sliding window
|
||||
self.assertGreater(out.shape[1], model.config.sliding_window)
|
||||
self.assertEqual(out.shape[1], input_size + 5)
|
||||
|
||||
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
|
||||
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
|
||||
# Note: Auto-inheritance only works for models saved starting from 4.50.0
|
||||
model.generation_config.transformers_version = "4.49.0"
|
||||
with self.assertRaises(RuntimeError): # errors out because it is not using hybrid cache
|
||||
out = model.generate(**inputs, generation_config=generation_config)
|
||||
|
Loading…
Reference in New Issue
Block a user