Config: unified logic to retrieve text config (#33219)

This commit is contained in:
Joao Gante 2024-09-04 12:03:30 +01:00 committed by GitHub
parent ebbe8d8014
commit d750b509fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 91 additions and 88 deletions

View File

@ -67,4 +67,4 @@ def main():
if __name__ == "__main__":
main()
main()

View File

@ -1019,17 +1019,17 @@ class PretrainedConfig(PushToHubMixin):
"""
non_default_generation_parameters = {}
decoder_attribute_name = None
default_config = None
# Composite models don't have a default config, use their decoder config as a fallback for default values
# If no known pattern is matched, then `default_config = None` -> check against the global generation defaults
try:
default_config = self.__class__()
except ValueError:
for decoder_attribute_name in ("decoder", "generator", "text_config"):
if hasattr(self, decoder_attribute_name):
default_config = getattr(self, decoder_attribute_name).__class__()
break
decoder_config = self.get_text_config(decoder=True)
if decoder_config is not self:
default_config = decoder_config.__class__()
else:
decoder_config = None
# If it is a composite model, we want to check the subconfig that will be used for generation
self_decoder_config = self if decoder_attribute_name is None else getattr(self, decoder_attribute_name)
@ -1057,6 +1057,36 @@ class PretrainedConfig(PushToHubMixin):
return non_default_generation_parameters
def get_text_config(self, decoder=False) -> "PretrainedConfig":
"""
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
itself. On specific composite models, it is under a set of valid names.
If `decoder` is set to `True`, then only search for decoder config names.
"""
decoder_possible_text_config_names = ("decoder", "generator", "text_config")
encoder_possible_text_config_names = ("text_encoder",)
if decoder:
possible_text_config_names = decoder_possible_text_config_names
else:
possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
valid_text_config_names = []
for text_config_name in possible_text_config_names:
if hasattr(self, text_config_name):
text_config = getattr(self, text_config_name, None)
if text_config is not None:
valid_text_config_names += [text_config_name]
if len(valid_text_config_names) > 1:
raise ValueError(
f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
"case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
)
elif len(valid_text_config_names) == 1:
return getattr(self, valid_text_config_names[0])
return self
def get_configuration_file(configuration_files: List[str]) -> str:
"""

View File

@ -1192,25 +1192,30 @@ class GenerationConfig(PushToHubMixin):
"""
config_dict = model_config.to_dict()
config_dict.pop("_from_model_config", None)
config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
# generation config.
for decoder_name in ("decoder", "generator", "text_config"):
if decoder_name in config_dict:
default_generation_config = GenerationConfig()
decoder_config = config_dict[decoder_name]
for attr in config.to_dict().keys():
if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
setattr(config, attr, decoder_config[attr])
# generation config (which in turn is defined from the outer attributes of model config).
decoder_config = model_config.get_text_config(decoder=True)
if decoder_config is not model_config:
default_generation_config = GenerationConfig()
decoder_config_dict = decoder_config.to_dict()
for attr in generation_config.to_dict().keys():
is_unset = getattr(generation_config, attr) == getattr(default_generation_config, attr)
if attr in decoder_config_dict and is_unset:
setattr(generation_config, attr, decoder_config_dict[attr])
# If any `output_...` flag is set to `True`, we ensure `return_dict_in_generate` is set to `True`.
if config.return_dict_in_generate is False:
if any(getattr(config, extra_output_flag, False) for extra_output_flag in config.extra_output_flags):
config.return_dict_in_generate = True
if generation_config.return_dict_in_generate is False:
if any(
getattr(generation_config, extra_output_flag, False)
for extra_output_flag in generation_config.extra_output_flags
):
generation_config.return_dict_in_generate = True
config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
return config
# Hash to detect whether the instance was modified
generation_config._original_object_hash = hash(generation_config)
return generation_config
def update(self, **kwargs):
"""

View File

@ -209,10 +209,7 @@ def get_modules_to_fuse(model, quantization_config):
current_fused_mapping = AWQ_FUSED_MAPPINGS[model.config.model_type]
# Properly deal with the case where we have a multi-modal model as well (e.g. Llava)
if not hasattr(model.config, "text_config"):
config = model.config
else:
config = model.config.text_config
config = model.config.get_text_config(decoder=True)
# Handle hidden_size, num_attention_heads, num_key_value_heads on our own.
hidden_size = config.hidden_size
@ -345,11 +342,8 @@ def _fuse_awq_mlp(model, current_module_name, fuse_module_names, module, target_
previous_device = gate_proj.qweight.device
# Deal also with the case model has `text_config` attribute
hidden_act = (
model.config.hidden_act
if not hasattr(model.config, "text_config")
else model.config.text_config.hidden_act
)
config = model.config.get_text_config(decoder=True)
hidden_act = config.hidden_act
activation_fn = ACT2FN[hidden_act]
new_module = target_cls(gate_proj, down_proj, up_proj, activation_fn)

View File

@ -2025,11 +2025,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else:
vocab_size = model_embeds.weight.shape[0]
# Update base model and current model config
if hasattr(self.config, "text_config"):
self.config.text_config.vocab_size = vocab_size
else:
self.config.vocab_size = vocab_size
# Update base model and current model config.
self.config.get_text_config().vocab_size = vocab_size
self.vocab_size = vocab_size
# Tie weights again if needed

View File

@ -735,7 +735,7 @@ class ClvpPreTrainedModel(PreTrainedModel):
nn.init.normal_(module.fc1.proj.weight if getattr(module.fc1, "proj") else module.fc1.weight, std=fc_std)
nn.init.normal_(module.fc2.weight, std=in_proj_std)
elif isinstance(module, ClvpEncoder):
config = self.config.text_config if hasattr(self.config, "text_config") else self.config
config = self.config.get_text_config()
factor = config.initializer_factor
module.projection.weight.data.normal_(mean=0.0, std=factor * (config.hidden_size**-0.5))
elif isinstance(module, ClvpConditioningEncoder):

View File

@ -1330,7 +1330,7 @@ class OlmoeForCausalLM(OlmoePreTrainedModel):
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
num_logits_to_keep=None,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1381,6 +1381,9 @@ class OlmoeForCausalLM(OlmoePreTrainedModel):
batch_size=batch_size,
)
if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep
model_inputs.update(
{
"position_ids": position_ids,
@ -1388,7 +1391,6 @@ class OlmoeForCausalLM(OlmoePreTrainedModel):
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs

View File

@ -831,7 +831,7 @@ class GenerationTesterMixin:
# Sample constraints
min_id = 3
max_id = config.vocab_size
max_id = config.get_text_config(decoder=True).vocab_size
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
@ -889,7 +889,7 @@ class GenerationTesterMixin:
# Sample constraints
min_id = 3
max_id = model.config.vocab_size
max_id = model.config.get_text_config(decoder=True).vocab_size
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
PhrasalConstraint(force_tokens),
@ -2012,18 +2012,20 @@ class GenerationTesterMixin:
self.assertTrue(output.past_key_values is None)
def _check_scores(self, batch_size, scores, length, config):
expected_shape = (batch_size, config.vocab_size)
vocab_size = config.get_text_config(decoder=True).vocab_size
expected_shape = (batch_size, vocab_size)
self.assertIsInstance(scores, tuple)
self.assertEqual(len(scores), length)
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
def _check_logits(self, batch_size, scores, config):
vocab_size = config.get_text_config(decoder=True).vocab_size
self.assertIsInstance(scores, tuple)
self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores))
# vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
vocab_diff = config.vocab_size - scores[0].shape[-1]
vocab_diff = vocab_size - scores[0].shape[-1]
self.assertTrue(vocab_diff in [0, 1])
self.assertListEqual([config.vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))
self.assertListEqual([vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1

View File

@ -1747,12 +1747,13 @@ class ModelTesterMixin:
self.assertTrue(models_equal)
def test_resize_tokens_embeddings(self):
if not self.test_resize_embeddings:
self.skipTest(reason="test_resize_embeddings is set to `False`")
(
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings:
self.skipTest(reason="test_resize_embeddings is set to `False`")
for model_class in self.all_model_classes:
config = copy.deepcopy(original_config)
@ -1764,18 +1765,15 @@ class ModelTesterMixin:
if self.model_tester.is_training is False:
model.eval()
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
model_vocab_size = config.get_text_config().vocab_size
# Retrieve the embeddings and clone theme
model_embed = model.resize_token_embeddings(model_vocab_size)
cloned_embeddings = model_embed.weight.clone()
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size + 10)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
# Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
@ -1787,11 +1785,7 @@ class ModelTesterMixin:
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size - 15)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
# Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
@ -1817,21 +1811,13 @@ class ModelTesterMixin:
model = model_class(config)
model.to(torch_device)
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
model_vocab_size = config.get_text_config().vocab_size
model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertTrue(new_model_vocab_size + 10, model_vocab_size)
model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
self.assertTrue(model_embed.weight.shape[0], new_model_vocab_size)
@ -1852,13 +1838,10 @@ class ModelTesterMixin:
model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)
def test_resize_embeddings_untied(self):
(
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings:
self.skipTest(reason="test_resize_embeddings is set to `False`")
original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
original_config.tie_word_embeddings = False
# if model cannot untied embeddings -> leave test
@ -1874,13 +1857,9 @@ class ModelTesterMixin:
continue
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
model_vocab_size = config.get_text_config().vocab_size
model.resize_token_embeddings(model_vocab_size + 10)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
output_embeds = model.get_output_embeddings()
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
@ -1892,11 +1871,7 @@ class ModelTesterMixin:
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
model.resize_token_embeddings(model_vocab_size - 15)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
# Check that it actually resizes the embeddings matrix
output_embeds = model.get_output_embeddings()
@ -1988,7 +1963,7 @@ class ModelTesterMixin:
# self.assertTrue(check_same_values(embeddings, decoding))
# Check that after resize they remain tied.
vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
vocab_size = config.get_text_config().vocab_size
model_tied.resize_token_embeddings(vocab_size + 10)
params_tied_2 = list(model_tied.parameters())
self.assertEqual(len(params_tied_2), len(params_tied))
@ -4831,7 +4806,7 @@ class ModelTesterMixin:
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
batch_size, sequence_length = inputs["input_ids"].shape
vocab_size = config.vocab_size
vocab_size = config.get_text_config().vocab_size
model = model_class(config).to(device=torch_device).eval()
# num_logits_to_keep=0 is a special case meaning "keep all logits"

View File

@ -675,14 +675,12 @@ def validate_test_components(test_case, task, model, tokenizer, processor):
# Avoid `IndexError` in embedding layers
CONFIG_WITHOUT_VOCAB_SIZE = ["CanineConfig"]
if tokenizer is not None:
config_vocab_size = getattr(model.config, "vocab_size", None)
# Removing `decoder=True` in `get_text_config` can lead to conflicting values e.g. in MusicGen
config_vocab_size = getattr(model.config.get_text_config(decoder=True), "vocab_size", None)
# For CLIP-like models
if config_vocab_size is None:
if hasattr(model.config, "text_config"):
if hasattr(model.config, "text_encoder"):
config_vocab_size = getattr(model.config.text_config, "vocab_size", None)
elif hasattr(model.config, "text_encoder"):
config_vocab_size = getattr(model.config.text_encoder, "vocab_size", None)
if config_vocab_size is None and model.config.__class__.__name__ not in CONFIG_WITHOUT_VOCAB_SIZE:
raise ValueError(
"Could not determine `vocab_size` from model configuration while `tokenizer` is not `None`."