Config: warning when saving generation kwargs in the model config (#28514)

Joao Gante 2024-01-16 18:31:01 +00:00 committed by GitHub
parent 7142bdfa90
commit f4f57f9dfa
6 changed files with 107 additions and 32 deletions
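
At a glance: model configs keep accepting the legacy generation kwargs at load time, but `save_pretrained` now warns when any of them holds a non-default value, steering users toward a dedicated `GenerationConfig`. A minimal sketch of the new behavior, using `BertConfig` and `min_length=3` exactly as in the new tests (the directory name is illustrative):

# Saving a model config that carries a non-default generation kwarg now warns
from transformers import BertConfig

config = BertConfig(min_length=3)  # generation kwarg stored in the model config
config.save_pretrained("./my-model")  # logs "Some non-default generation parameters..."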

src/transformers/configuration_utils.py

@@ -277,6 +277,7 @@ class PretrainedConfig(PushToHubMixin):
         self.tie_word_embeddings = kwargs.pop(
             "tie_word_embeddings", True
         )  # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
+        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
         # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
         self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
@@ -285,33 +286,10 @@ class PretrainedConfig(PushToHubMixin):
         self.add_cross_attention = kwargs.pop("add_cross_attention", False)
         self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)

-        # Parameters for sequence generation
-        self.max_length = kwargs.pop("max_length", 20)
-        self.min_length = kwargs.pop("min_length", 0)
-        self.do_sample = kwargs.pop("do_sample", False)
-        self.early_stopping = kwargs.pop("early_stopping", False)
-        self.num_beams = kwargs.pop("num_beams", 1)
-        self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
-        self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
-        self.temperature = kwargs.pop("temperature", 1.0)
-        self.top_k = kwargs.pop("top_k", 50)
-        self.top_p = kwargs.pop("top_p", 1.0)
-        self.typical_p = kwargs.pop("typical_p", 1.0)
-        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
-        self.length_penalty = kwargs.pop("length_penalty", 1.0)
-        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
-        self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
-        self.bad_words_ids = kwargs.pop("bad_words_ids", None)
-        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
-        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
-        self.output_scores = kwargs.pop("output_scores", False)
-        self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
-        self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
-        self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
-        self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
-        self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)
-        self.suppress_tokens = kwargs.pop("suppress_tokens", None)
-        self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
+        # Retrocompatibility: Parameters for sequence generation. While we will keep the ability to load these
+        # parameters, saving them will be deprecated. In a distant future, we won't need to load them.
+        for parameter_name, default_value in self._get_generation_defaults().items():
+            setattr(self, parameter_name, kwargs.pop(parameter_name, default_value))

         # Fine-tuning task arguments
         self.architectures = kwargs.pop("architectures", None)
@@ -463,6 +441,18 @@ class PretrainedConfig(PushToHubMixin):
         if os.path.isfile(save_directory):
             raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

+        non_default_generation_parameters = {}
+        for parameter_name, default_value in self._get_generation_defaults().items():
+            if hasattr(self, parameter_name) and getattr(self, parameter_name) != default_value:
+                non_default_generation_parameters[parameter_name] = getattr(self, parameter_name)
+        if len(non_default_generation_parameters) > 0:
+            logger.warning(
+                "Some non-default generation parameters are set in the model config. These should go into a "
+                "GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) "
+                "instead. This warning will be raised to an exception in v4.41.\n"
+                f"Non-default generation parameters: {str(non_default_generation_parameters)}"
+            )
+
         os.makedirs(save_directory, exist_ok=True)

         if push_to_hub:
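
The warning above points to a `GenerationConfig` file as the proper home for these parameters. A hedged sketch of the recommended migration, reusing the same `min_length` example (directory name illustrative):

# Keep the model config clean and put generation kwargs in a GenerationConfig instead
from transformers import BertConfig, GenerationConfig

config = BertConfig()                            # no generation kwargs -> no warning
generation_config = GenerationConfig(min_length=3)

config.save_pretrained("./my-model")             # writes config.json
generation_config.save_pretrained("./my-model")  # writes generation_config.json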
@@ -1050,6 +1040,45 @@ class PretrainedConfig(PushToHubMixin):
         cls._auto_class = auto_class

+    @staticmethod
+    def _get_generation_defaults() -> Dict[str, Any]:
+        return {
+            "max_length": 20,
+            "min_length": 0,
+            "do_sample": False,
+            "early_stopping": False,
+            "num_beams": 1,
+            "num_beam_groups": 1,
+            "diversity_penalty": 0.0,
+            "temperature": 1.0,
+            "top_k": 50,
+            "top_p": 1.0,
+            "typical_p": 1.0,
+            "repetition_penalty": 1.0,
+            "length_penalty": 1.0,
+            "no_repeat_ngram_size": 0,
+            "encoder_no_repeat_ngram_size": 0,
+            "bad_words_ids": None,
+            "num_return_sequences": 1,
+            "output_scores": False,
+            "return_dict_in_generate": False,
+            "forced_bos_token_id": None,
+            "forced_eos_token_id": None,
+            "remove_invalid_values": False,
+            "exponential_decay_length_penalty": None,
+            "suppress_tokens": None,
+            "begin_suppress_tokens": None,
+        }
+
+    def _has_non_default_generation_parameters(self) -> bool:
+        """
+        Whether or not this instance holds non-default generation parameters.
+        """
+        for parameter_name, default_value in self._get_generation_defaults().items():
+            if hasattr(self, parameter_name) and getattr(self, parameter_name) != default_value:
+                return True
+        return False
+

 def get_configuration_file(configuration_files: List[str]) -> str:
     """

src/transformers/generation/utils.py

@@ -1274,11 +1274,14 @@ class GenerationMixin:
         # priority: `generation_config` argument > `model.generation_config` (the default generation config)
         if generation_config is None:
             # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
-            # two conditions must be met
+            # three conditions must be met
             # 1) the generation config must have been created from the model config (`_from_model_config` field);
-            # 2) the generation config must have seen no modification since its creation (the hash is the same).
-            if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
-                self.generation_config
+            # 2) the generation config must have seen no modification since its creation (the hash is the same);
+            # 3) the user must have set generation parameters in the model config.
+            if (
+                self.generation_config._from_model_config
+                and self.generation_config._original_object_hash == hash(self.generation_config)
+                and self.config._has_non_default_generation_parameters()
             ):
                 new_generation_config = GenerationConfig.from_model_config(self.config)
                 if new_generation_config != self.generation_config:
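
The new third condition narrows the legacy path: `generate` now only rebuilds its generation config from the model config when that config actually carries non-default generation kwargs. A sketch of what still triggers the fallback, assuming GPT-2 purely for illustration:

# Legacy (discouraged) parameterization of generate() through the model config
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

model.config.top_k = 1  # non-default generation kwarg -> condition 3 is now met
inputs = tokenizer("Hello", return_tensors="pt")
out = model.generate(**inputs)  # rebuilds the generation config from model.config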

src/transformers/modeling_utils.py

@@ -2335,6 +2335,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         if not _hf_peft_config_loaded:
             model_to_save.config.save_pretrained(save_directory)
         if self.can_generate():
+            # generation config built from the model config + the model config holds generation kwargs -> generate
+            # may revert to legacy behavior if the two don't match
+            if (
+                model_to_save.generation_config._from_model_config
+                and model_to_save.config._has_non_default_generation_parameters()
+            ):
+                new_generation_config = GenerationConfig.from_model_config(model_to_save.config)
+                if new_generation_config != model_to_save.generation_config:
+                    logger.warning(
+                        "Your generation config was originally created from the model config, but the model "
+                        "config has changed since then. Unless you pass the `generation_config` argument to this "
+                        "model's `generate` calls, they will revert to the legacy behavior where the base "
+                        "`generate` parameterization is loaded from the model config instead. "
+                        "To avoid this behavior and this warning, we recommend you to overwrite the generation "
+                        "config model attribute before calling the model's `save_pretrained`, preferably also "
+                        "removing any generation kwargs from the model config. This warning will be raised to an "
+                        "exception in v4.41."
+                    )
             model_to_save.generation_config.save_pretrained(save_directory)

         if _hf_peft_config_loaded:
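
Per the warning text, the remedy is to sync the generation config before saving. A hedged sketch of that suggestion (model choice and directory are illustrative):

from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.config.top_k = 1  # drifted model config: would trigger the warning above

# Overwrite the generation config attribute, as the warning recommends, and
# preferably reset the generation kwarg in the model config to its default.
model.generation_config = GenerationConfig.from_model_config(model.config)
model.config.top_k = 50
model.save_pretrained("./my-model")  # saves without warnings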

tests/generation/test_framework_agnostic.py

@@ -529,7 +529,7 @@ class GenerationIntegrationTestsMixin:
         pixel_values = floats_tensor((2, 3, 30, 30))
         model = model_cls.from_pretrained("hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2")
-        model.config.decoder.eos_token_id = None
+        model.generation_config.eos_token_id = None
         if is_pt:
             pixel_values = pixel_values.to(torch_device)
             model = model.to(torch_device)

tests/test_configuration_utils.py

@@ -296,3 +296,19 @@ class ConfigTestUtils(unittest.TestCase):
         old_transformers.configuration_utils.__version__ = "v3.0.0"
         old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
         self.assertEqual(old_configuration.hidden_size, 768)
+
+    def test_saving_config_with_custom_generation_kwargs_raises_warning(self):
+        config = BertConfig(min_length=3)  # `min_length = 3` is a non-default generation kwarg
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self.assertLogs("transformers.configuration_utils", level="WARNING") as logs:
+                config.save_pretrained(tmp_dir)
+            self.assertEqual(len(logs.output), 1)
+            self.assertIn("min_length", logs.output[0])
+
+    def test_has_non_default_generation_parameters(self):
+        config = BertConfig()
+        self.assertFalse(config._has_non_default_generation_parameters())
+        config = BertConfig(min_length=3)
+        self.assertTrue(config._has_non_default_generation_parameters())
+        config = BertConfig(min_length=0)  # `min_length = 0` is a default generation kwarg
+        self.assertFalse(config._has_non_default_generation_parameters())

tests/test_modeling_utils.py

@@ -1230,6 +1230,15 @@ class ModelUtilsTest(TestCasePlus):
         for p1, p2 in zip(model.parameters(), new_model.parameters()):
             self.assertTrue(torch.equal(p1, p2))

+    def test_modifying_model_config_causes_warning_saving_generation_config(self):
+        model = AutoModelForCausalLM.from_pretrained("gpt2")
+        model.config.top_k = 1
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self.assertLogs("transformers.modeling_utils", level="WARNING") as logs:
+                model.save_pretrained(tmp_dir)
+            self.assertEqual(len(logs.output), 1)
+            self.assertIn("Your generation config was originally created from the model config", logs.output[0])
+
     @slow
     @require_torch