[Generation, Gemma 3] When passing a custom generation_config, overwrite default values with the model's base generation_config (#36684)

Joao Gante 2025-03-15 12:40:09 +00:00 committed by GitHub
parent f263e88dcf
commit fc8764c9a6
11 changed files with 108 additions and 24 deletions

View File

@@ -73,7 +73,7 @@ if is_torch_available():
}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
ALL_CACHE_IMPLEMENTATIONS = (
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(CACHE_CONFIG_MAPPING.keys()) + ["offloaded"]
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(CACHE_CONFIG_MAPPING.keys()) + ["offloaded", "dynamic"]
)
@@ -175,6 +175,7 @@ class GenerationConfig(PushToHubMixin):
cache_implementation (`str`, *optional*, default to `None`):
Name of the cache class that will be instantiated in `generate`, for faster decoding. Possible values are:
- `"dynamic"`: [`DynamicCache`]
- `"static"`: [`StaticCache`]
- `"offloaded_static"`: [`OffloadedStaticCache`]
- `"sliding_window"`: [`SlidingWindowCache`]
@@ -182,9 +183,8 @@ class GenerationConfig(PushToHubMixin):
- `"mamba"`: [`MambaCache`]
- `"quantized"`: [`QuantizedCache`]
We support other cache types, but they must be manually instantiated and
passed to `generate` through the `past_key_values` argument. See our
[cache documentation](https://huggingface.co/docs/transformers/en/kv_cache) for further information.
If none is specified, we will use the default cache for the model (which is often [`DynamicCache`]). See
our [cache documentation](https://huggingface.co/docs/transformers/en/kv_cache) for further information.
cache_config (`CacheConfig` or `dict`, *optional*, default to `None`):
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
it will be converted to its respective `CacheConfig` internally.
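As a usage sketch (the checkpoint id and prompt below are placeholders, not part of this commit), the cache implementation can be selected by name on a `GenerationConfig` or directly as a `generate` kwarg; with this change `"dynamic"` is accepted alongside the existing names:

```python
# Minimal sketch, assuming any causal LM checkpoint; "gpt2" is only an example.
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("A list of colors: red, blue", return_tensors="pt")

# Explicitly request the (usually default) DynamicCache by name. Other names such as
# "static", "offloaded" or "quantized" select the corresponding cache classes listed above.
generation_config = GenerationConfig(max_new_tokens=20, cache_implementation="dynamic")
out = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```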

View File

@@ -1177,21 +1177,37 @@ class GenerationMixin:
default_list: Union[LogitsProcessorList, StoppingCriteriaList],
custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
) -> Union[LogitsProcessorList, StoppingCriteriaList]:
"""
Merge user-defined processors/criteria with the ones instantiated inside `generate`. In case the same
processor/criteria is present on both lists, use the user-defined one.
(Note: up to v4.49.0, this function threw an exception if the same logit processor was found twice.)
"""
if len(custom_list) == 0:
return default_list
final_list = type(default_list)()
for default in default_list:
using_custom = False
for custom in custom_list:
if type(custom) is type(default):
object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
raise ValueError(
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
f" `.generate()`, but it has already been created with the values {default}. {default} has been"
" created by passing the corresponding arguments to generate or by the model's config default"
f" values. If you just want to change the default values of {object_type} consider passing"
f" them as arguments to `.generate()` instead of using a custom {object_type}."
logger.warning_once(
f"A custom {object_type} of type {type(custom)} has been passed to `.generate()`, but it "
f"was also created in `.generate()`, given its parameterization. The custom {type(custom)} "
f"will take precedence. Please check the docstring of {type(custom)} to see related "
"`.generate()` flags."
)
default_list.extend(custom_list)
return default_list
final_list.append(custom)
using_custom = True
break
if not using_custom:
final_list.append(default)
for custom in custom_list:
if custom not in final_list:
final_list.append(custom)
return final_list
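A hedged sketch of the behavioral change in this merge logic (the checkpoint id and `min_length` values are illustrative): a user-supplied processor of the same type as one `generate` builds internally no longer raises a `ValueError`; the custom instance takes precedence and a one-time warning is logged:

```python
# Sketch only; "gpt2" and the min_length values are placeholders.
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# `min_length=10` makes `generate` build its own MinLengthLogitsProcessor. Up to v4.49.0,
# also passing a custom MinLengthLogitsProcessor raised a ValueError; now the custom one
# below takes precedence and a warning points at the related `.generate()` flag instead.
custom = LogitsProcessorList([MinLengthLogitsProcessor(20, eos_token_id=tokenizer.eos_token_id)])
out = model.generate(**inputs, min_length=10, max_new_tokens=30, logits_processor=custom)
```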
def compute_transition_scores(
self,
@@ -1573,17 +1589,28 @@ class GenerationMixin:
# exception will be raised in `_validate_model_kwargs`
if not is_torchdynamo_compiling():
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
# If `generation_config` is provided, let's fallback ALL default values to the model's generation config
# TODO (joao): per-model generation config classes.
if not using_model_generation_config:
if generation_config.bos_token_id is None:
generation_config.bos_token_id = self.generation_config.bos_token_id
if generation_config.eos_token_id is None:
generation_config.eos_token_id = self.generation_config.eos_token_id
if generation_config.pad_token_id is None:
generation_config.pad_token_id = self.generation_config.pad_token_id
if generation_config.decoder_start_token_id is None:
generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
modified_values = {}
default_generation_config = GenerationConfig()
for key, default_value in default_generation_config.__dict__.items():
if key.startswith("_"): # metadata
continue
custom_gen_config_value = getattr(generation_config, key)
model_gen_config_value = getattr(self.generation_config, key)
if custom_gen_config_value == default_value and model_gen_config_value != default_value:
modified_values[key] = model_gen_config_value
setattr(generation_config, key, model_gen_config_value)
if len(modified_values) > 0:
logger.warning_once(
f"`generation_config` default values have been modified to match model-specific defaults: "
f"{modified_values}. If this is not desired, please set these values explicitly."
)
# Finally, apply any passed kwargs
model_kwargs = generation_config.update(**kwargs)
else:
model_kwargs = kwargs
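A minimal sketch of the new fallback, using the Gemma 3 case this commit targets (the checkpoint id is illustrative and may require authentication): fields left at library defaults on a user-provided `GenerationConfig` are now filled in from `model.generation_config`, so model-specific defaults such as `cache_implementation="hybrid"` are no longer silently dropped, while explicitly set values still win:

```python
# Sketch under the assumption that the checkpoint ships a generation_config with
# non-default values (e.g. cache_implementation="hybrid" for Gemma 3).
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_id = "google/gemma-3-1b-it"  # placeholder checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
inputs = tokenizer("This is a nice place. I really enjoy the scenery,", return_tensors="pt")

# Only `max_new_tokens` is set here. Before this fix, every other field kept the library
# default (e.g. `cache_implementation=None`); now unset fields inherit the model's values
# and a one-time warning lists which defaults were replaced.
generation_config = GenerationConfig(max_new_tokens=20)
out = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.decode(out[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True))
```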
@@ -1837,6 +1864,8 @@ class GenerationMixin:
model_kwargs[cache_name] = cache_class(cache_config)
elif generation_config.cache_implementation == "offloaded":
model_kwargs[cache_name] = OffloadedCache()
elif generation_config.cache_implementation == "dynamic":
model_kwargs[cache_name] = DynamicCache()
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
# keeps copying the cache thus using much more memory
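For reference, a hedged sketch of what the new `"dynamic"` branch amounts to (the checkpoint is again a placeholder): selecting the cache by name should behave the same as instantiating a `DynamicCache` yourself and passing it through `past_key_values`:

```python
# Sketch only; "gpt2" is a placeholder, and greedy decoding is used so both calls match.
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("A list of colors: red, blue", return_tensors="pt")

# The new branch simply places a fresh DynamicCache() into model_kwargs["past_key_values"].
out_by_name = model.generate(**inputs, max_new_tokens=10, cache_implementation="dynamic", do_sample=False)
out_by_object = model.generate(**inputs, max_new_tokens=10, past_key_values=DynamicCache(), do_sample=False)
assert (out_by_name == out_by_object).all()
```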

View File

@@ -1162,8 +1162,8 @@ class GenerationTesterMixin:
# The two outputs must match and their shape must be as expected
self._check_similar_generate_outputs(low_output, high_output)
@pytest.mark.generate
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
# This test ensures that the assisted generation does not introduce output changes over greedy search.
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.

View File

@@ -16,6 +16,7 @@
import unittest
import pytest
from parameterized import parameterized
from transformers import (
@@ -261,6 +262,7 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -269,6 +271,7 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass

View File

@@ -16,6 +16,7 @@
import unittest
import pytest
from packaging import version
from parameterized import parameterized
from pytest import mark
@@ -81,6 +82,7 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -89,6 +91,7 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass

View File

@@ -299,12 +299,13 @@ class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@pytest.mark.generate
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices")
def test_assisted_decoding_matches_greedy_search(self):
pass
@pytest.mark.generate
@unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices")
def test_assisted_decoding_sample(self):
pass

View File

@@ -16,6 +16,7 @@
import unittest
import pytest
from packaging import version
from parameterized import parameterized
from pytest import mark
@@ -96,6 +97,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -104,6 +106,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass

View File

@@ -16,6 +16,7 @@
import unittest
import pytest
from parameterized import parameterized
from transformers import (
@@ -23,6 +24,7 @@ from transformers import (
AutoTokenizer,
Gemma3Config,
Gemma3TextConfig,
GenerationConfig,
is_torch_available,
)
from transformers.testing_utils import (
@@ -75,6 +77,7 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -83,6 +86,7 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@@ -277,6 +281,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -285,6 +290,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@@ -551,3 +557,34 @@ class Gemma3IntegrationTest(unittest.TestCase):
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
def test_generation_beyond_sliding_window_with_generation_config(self):
"""
Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
"""
model_id = "gg-hf-g/gemma-3-1b-it"
attn_implementation = "sdpa"
input_text = [
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
"A list of colors: red, blue", # This will almost all be padding tokens
]
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
).to(torch_device)
# Make sure prefill is larger than sliding window
input_size = inputs.input_ids.shape[-1]
self.assertTrue(input_size > model.config.sliding_window)
generation_config = GenerationConfig(max_new_tokens=20)
out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
output_text = tokenizer.batch_decode(out)
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
self.assertEqual(output_text, EXPECTED_COMPLETIONS)

View File

@@ -16,6 +16,7 @@
import unittest
import pytest
from parameterized import parameterized
from transformers import (
@@ -351,6 +352,7 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -359,6 +361,7 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass

View File

@@ -16,6 +16,8 @@
import unittest
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
from transformers.testing_utils import (
require_bitsandbytes,
@@ -375,6 +377,7 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
def test_model_parallel_beam_search(self):
pass
@pytest.mark.generate
@unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. Not supported")
def test_assisted_decoding_matches_greedy_search(self):
pass
@@ -383,6 +386,7 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
def test_left_padding_compatibility(self):
pass
@pytest.mark.generate
@unittest.skip(reason="Relies on `past_key_values` returned by the model. Not supported with recurrent gemma")
def test_assisted_decoding_sample(self):
pass

View File

@@ -423,6 +423,7 @@ class SmolVLMForConditionalGenerationModelTest(GenerationTesterMixin, ModelTeste
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip(reason="Cache position is off by one leaving out image tokens, FIXME raushan")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass