mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Generation, Gemma 3] When passing a custom generation_config
, overwrite default values with the model's base generation_config
(#36684)
This commit is contained in:
parent
f263e88dcf
commit
fc8764c9a6
@ -73,7 +73,7 @@ if is_torch_available():
|
||||
}
|
||||
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
|
||||
ALL_CACHE_IMPLEMENTATIONS = (
|
||||
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(CACHE_CONFIG_MAPPING.keys()) + ["offloaded"]
|
||||
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(CACHE_CONFIG_MAPPING.keys()) + ["offloaded", "dynamic"]
|
||||
)
|
||||
|
||||
|
||||
@ -175,6 +175,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
cache_implementation (`str`, *optional*, default to `None`):
|
||||
Name of the cache class that will be instantiated in `generate`, for faster decoding. Possible values are:
|
||||
|
||||
- `"dynamic"`: [`DynamicCache`]
|
||||
- `"static"`: [`StaticCache`]
|
||||
- `"offloaded_static"`: [`OffloadedStaticCache`]
|
||||
- `"sliding_window"`: [`SlidingWindowCache`]
|
||||
@ -182,9 +183,8 @@ class GenerationConfig(PushToHubMixin):
|
||||
- `"mamba"`: [`MambaCache`]
|
||||
- `"quantized"`: [`QuantizedCache`]
|
||||
|
||||
We support other cache types, but they must be manually instantiated and
|
||||
passed to `generate` through the `past_key_values` argument. See our
|
||||
[cache documentation](https://huggingface.co/docs/transformers/en/kv_cache) for further information.
|
||||
If none is specified, we will use the default cache for the model (which is often [`DynamicCache`]). See
|
||||
our [cache documentation](https://huggingface.co/docs/transformers/en/kv_cache) for further information.
|
||||
cache_config (`CacheConfig` or `dict`, *optional*, default to `None`):
|
||||
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
|
||||
it will be converted to its repsective `CacheConfig` internally.
|
||||
|
@ -1177,21 +1177,37 @@ class GenerationMixin:
|
||||
default_list: Union[LogitsProcessorList, StoppingCriteriaList],
|
||||
custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
|
||||
) -> Union[LogitsProcessorList, StoppingCriteriaList]:
|
||||
"""
|
||||
Merge user-defined processors/criteria with the ones instantiated inside `generate`. In case the same
|
||||
processor/criteria is present on both lists, use the user-defined one.
|
||||
|
||||
(Note: up to v4.49.0, this funtion threw an exception is the same logit processor was found twice.)
|
||||
"""
|
||||
if len(custom_list) == 0:
|
||||
return default_list
|
||||
|
||||
final_list = type(default_list)()
|
||||
for default in default_list:
|
||||
using_custom = False
|
||||
for custom in custom_list:
|
||||
if type(custom) is type(default):
|
||||
object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
|
||||
raise ValueError(
|
||||
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
|
||||
f" `.generate()`, but it has already been created with the values {default}. {default} has been"
|
||||
" created by passing the corresponding arguments to generate or by the model's config default"
|
||||
f" values. If you just want to change the default values of {object_type} consider passing"
|
||||
f" them as arguments to `.generate()` instead of using a custom {object_type}."
|
||||
logger.warning_once(
|
||||
f"A custom {object_type} of type {type(custom)} has been passed to `.generate()`, but it "
|
||||
f"was also created in `.generate()`, given its parameterization. The custom {type(custom)} "
|
||||
f"will take precedence. Please check the docstring of {type(custom)} to see related "
|
||||
"`.generate()` flags."
|
||||
)
|
||||
default_list.extend(custom_list)
|
||||
return default_list
|
||||
final_list.append(custom)
|
||||
using_custom = True
|
||||
break
|
||||
if not using_custom:
|
||||
final_list.append(default)
|
||||
|
||||
for custom in custom_list:
|
||||
if custom not in final_list:
|
||||
final_list.append(custom)
|
||||
return final_list
|
||||
|
||||
def compute_transition_scores(
|
||||
self,
|
||||
@ -1573,17 +1589,28 @@ class GenerationMixin:
|
||||
# exception will be raised in `_validate_model_kwargs`
|
||||
if not is_torchdynamo_compiling():
|
||||
generation_config = copy.deepcopy(generation_config)
|
||||
model_kwargs = generation_config.update(**kwargs)
|
||||
# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
|
||||
|
||||
# If `generation_config` is provided, let's fallback ALL default values to the model's generation config
|
||||
# TODO (joao): per-model generation config classes.
|
||||
if not using_model_generation_config:
|
||||
if generation_config.bos_token_id is None:
|
||||
generation_config.bos_token_id = self.generation_config.bos_token_id
|
||||
if generation_config.eos_token_id is None:
|
||||
generation_config.eos_token_id = self.generation_config.eos_token_id
|
||||
if generation_config.pad_token_id is None:
|
||||
generation_config.pad_token_id = self.generation_config.pad_token_id
|
||||
if generation_config.decoder_start_token_id is None:
|
||||
generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
|
||||
modified_values = {}
|
||||
default_generation_config = GenerationConfig()
|
||||
for key, default_value in default_generation_config.__dict__.items():
|
||||
if key.startswith("_"): # metadata
|
||||
continue
|
||||
custom_gen_config_value = getattr(generation_config, key)
|
||||
model_gen_config_value = getattr(self.generation_config, key)
|
||||
if custom_gen_config_value == default_value and model_gen_config_value != default_value:
|
||||
modified_values[key] = model_gen_config_value
|
||||
setattr(generation_config, key, model_gen_config_value)
|
||||
if len(modified_values) > 0:
|
||||
logger.warning_once(
|
||||
f"`generation_config` default values have been modified to match model-specific defaults: "
|
||||
f"{modified_values}. If this is not desired, please set these values explicitly."
|
||||
)
|
||||
|
||||
# Finally, apply any passed kwargs
|
||||
model_kwargs = generation_config.update(**kwargs)
|
||||
else:
|
||||
model_kwargs = kwargs
|
||||
|
||||
@ -1837,6 +1864,8 @@ class GenerationMixin:
|
||||
model_kwargs[cache_name] = cache_class(cache_config)
|
||||
elif generation_config.cache_implementation == "offloaded":
|
||||
model_kwargs[cache_name] = OffloadedCache()
|
||||
elif generation_config.cache_implementation == "dynamic":
|
||||
model_kwargs[cache_name] = DynamicCache()
|
||||
|
||||
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
|
||||
# keeps copying the cache thus using much more memory
|
||||
|
@ -1162,8 +1162,8 @@ class GenerationTesterMixin:
|
||||
# The two outputs must match and their shape must be as expected
|
||||
self._check_similar_generate_outputs(low_output, high_output)
|
||||
|
||||
@pytest.mark.generate
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
||||
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
@ -261,6 +262,7 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
@ -269,6 +271,7 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
from pytest import mark
|
||||
@ -81,6 +82,7 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
@ -89,6 +91,7 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
@ -299,12 +299,13 @@ class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices")
|
||||
def test_assisted_decoding_matches_greedy_search(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
from pytest import mark
|
||||
@ -96,6 +97,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
@ -104,6 +106,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
@ -23,6 +24,7 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
Gemma3Config,
|
||||
Gemma3TextConfig,
|
||||
GenerationConfig,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
@ -75,6 +77,7 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
@ -83,6 +86,7 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
@ -277,6 +281,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
@ -285,6 +290,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
@ -551,3 +557,34 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
||||
|
||||
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
|
||||
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
|
||||
|
||||
def test_generation_beyond_sliding_window_with_generation_config(self):
|
||||
"""
|
||||
Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
|
||||
ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
|
||||
"""
|
||||
model_id = "gg-hf-g/gemma-3-1b-it"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
input_text = [
|
||||
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
|
||||
"A list of colors: red, blue", # This will almost all be padding tokens
|
||||
]
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
|
||||
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
|
||||
# Make sure prefill is larger than sliding window
|
||||
input_size = inputs.input_ids.shape[-1]
|
||||
self.assertTrue(input_size > model.config.sliding_window)
|
||||
|
||||
generation_config = GenerationConfig(max_new_tokens=20)
|
||||
|
||||
out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
|
||||
output_text = tokenizer.batch_decode(out)
|
||||
|
||||
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
|
||||
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
@ -351,6 +352,7 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
@ -359,6 +361,7 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
@ -16,6 +16,8 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import (
|
||||
require_bitsandbytes,
|
||||
@ -375,6 +377,7 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
|
||||
def test_model_parallel_beam_search(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. Not supported")
|
||||
def test_assisted_decoding_matches_greedy_search(self):
|
||||
pass
|
||||
@ -383,6 +386,7 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
|
||||
def test_left_padding_compatibility(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="Relies on `past_key_values` returned by the model. Not supported with recurrent gemma")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
@ -423,6 +423,7 @@ class SmolVLMForConditionalGenerationModelTest(GenerationTesterMixin, ModelTeste
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="Cache position is off by one leaving out image tokens, FIXME raushan")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
Loading…
Reference in New Issue
Block a user