Config: unified logic to retrieve text config (#33219)

This commit is contained in:
Joao Gante 2024-09-04 12:03:30 +01:00 committed by GitHub
parent ebbe8d8014
commit d750b509fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 91 additions and 88 deletions

View File

@ -1019,17 +1019,17 @@ class PretrainedConfig(PushToHubMixin):
""" """
non_default_generation_parameters = {} non_default_generation_parameters = {}
decoder_attribute_name = None decoder_attribute_name = None
default_config = None
# Composite models don't have a default config, use their decoder config as a fallback for default values # Composite models don't have a default config, use their decoder config as a fallback for default values
# If no known pattern is matched, then `default_config = None` -> check against the global generation defaults # If no known pattern is matched, then `default_config = None` -> check against the global generation defaults
try: try:
default_config = self.__class__() default_config = self.__class__()
except ValueError: except ValueError:
for decoder_attribute_name in ("decoder", "generator", "text_config"): decoder_config = self.get_text_config(decoder=True)
if hasattr(self, decoder_attribute_name): if decoder_config is not self:
default_config = getattr(self, decoder_attribute_name).__class__() default_config = decoder_config.__class__()
break else:
decoder_config = None
# If it is a composite model, we want to check the subconfig that will be used for generation # If it is a composite model, we want to check the subconfig that will be used for generation
self_decoder_config = self if decoder_attribute_name is None else getattr(self, decoder_attribute_name) self_decoder_config = self if decoder_attribute_name is None else getattr(self, decoder_attribute_name)
@ -1057,6 +1057,36 @@ class PretrainedConfig(PushToHubMixin):
return non_default_generation_parameters return non_default_generation_parameters
def get_text_config(self, decoder=False) -> "PretrainedConfig":
"""
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
itself. On specific composite models, it is under a set of valid names.
If `decoder` is set to `True`, then only search for decoder config names.
"""
decoder_possible_text_config_names = ("decoder", "generator", "text_config")
encoder_possible_text_config_names = ("text_encoder",)
if decoder:
possible_text_config_names = decoder_possible_text_config_names
else:
possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
valid_text_config_names = []
for text_config_name in possible_text_config_names:
if hasattr(self, text_config_name):
text_config = getattr(self, text_config_name, None)
if text_config is not None:
valid_text_config_names += [text_config_name]
if len(valid_text_config_names) > 1:
raise ValueError(
f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
"case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
)
elif len(valid_text_config_names) == 1:
return getattr(self, valid_text_config_names[0])
return self
def get_configuration_file(configuration_files: List[str]) -> str: def get_configuration_file(configuration_files: List[str]) -> str:
""" """

View File

@ -1192,25 +1192,30 @@ class GenerationConfig(PushToHubMixin):
""" """
config_dict = model_config.to_dict() config_dict = model_config.to_dict()
config_dict.pop("_from_model_config", None) config_dict.pop("_from_model_config", None)
config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True) generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the # Special case: some models have generation attributes set in the decoder. Use them if still unset in the
# generation config. # generation config (which in turn is defined from the outer attributes of model config).
for decoder_name in ("decoder", "generator", "text_config"): decoder_config = model_config.get_text_config(decoder=True)
if decoder_name in config_dict: if decoder_config is not model_config:
default_generation_config = GenerationConfig() default_generation_config = GenerationConfig()
decoder_config = config_dict[decoder_name] decoder_config_dict = decoder_config.to_dict()
for attr in config.to_dict().keys(): for attr in generation_config.to_dict().keys():
if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr): is_unset = getattr(generation_config, attr) == getattr(default_generation_config, attr)
setattr(config, attr, decoder_config[attr]) if attr in decoder_config_dict and is_unset:
setattr(generation_config, attr, decoder_config_dict[attr])
# If any `output_...` flag is set to `True`, we ensure `return_dict_in_generate` is set to `True`. # If any `output_...` flag is set to `True`, we ensure `return_dict_in_generate` is set to `True`.
if config.return_dict_in_generate is False: if generation_config.return_dict_in_generate is False:
if any(getattr(config, extra_output_flag, False) for extra_output_flag in config.extra_output_flags): if any(
config.return_dict_in_generate = True getattr(generation_config, extra_output_flag, False)
for extra_output_flag in generation_config.extra_output_flags
):
generation_config.return_dict_in_generate = True
config._original_object_hash = hash(config) # Hash to detect whether the instance was modified # Hash to detect whether the instance was modified
return config generation_config._original_object_hash = hash(generation_config)
return generation_config
def update(self, **kwargs): def update(self, **kwargs):
""" """

View File

@ -209,10 +209,7 @@ def get_modules_to_fuse(model, quantization_config):
current_fused_mapping = AWQ_FUSED_MAPPINGS[model.config.model_type] current_fused_mapping = AWQ_FUSED_MAPPINGS[model.config.model_type]
# Properly deal with the case where we have a multi-modal model as well (e.g. Llava) # Properly deal with the case where we have a multi-modal model as well (e.g. Llava)
if not hasattr(model.config, "text_config"): config = model.config.get_text_config(decoder=True)
config = model.config
else:
config = model.config.text_config
# Handle hidden_size, num_attention_heads, num_key_value_heads on our own. # Handle hidden_size, num_attention_heads, num_key_value_heads on our own.
hidden_size = config.hidden_size hidden_size = config.hidden_size
@ -345,11 +342,8 @@ def _fuse_awq_mlp(model, current_module_name, fuse_module_names, module, target_
previous_device = gate_proj.qweight.device previous_device = gate_proj.qweight.device
# Deal also with the case model has `text_config` attribute # Deal also with the case model has `text_config` attribute
hidden_act = ( config = model.config.get_text_config(decoder=True)
model.config.hidden_act hidden_act = config.hidden_act
if not hasattr(model.config, "text_config")
else model.config.text_config.hidden_act
)
activation_fn = ACT2FN[hidden_act] activation_fn = ACT2FN[hidden_act]
new_module = target_cls(gate_proj, down_proj, up_proj, activation_fn) new_module = target_cls(gate_proj, down_proj, up_proj, activation_fn)

View File

@ -2025,11 +2025,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else: else:
vocab_size = model_embeds.weight.shape[0] vocab_size = model_embeds.weight.shape[0]
# Update base model and current model config # Update base model and current model config.
if hasattr(self.config, "text_config"): self.config.get_text_config().vocab_size = vocab_size
self.config.text_config.vocab_size = vocab_size
else:
self.config.vocab_size = vocab_size
self.vocab_size = vocab_size self.vocab_size = vocab_size
# Tie weights again if needed # Tie weights again if needed

View File

@ -735,7 +735,7 @@ class ClvpPreTrainedModel(PreTrainedModel):
nn.init.normal_(module.fc1.proj.weight if getattr(module.fc1, "proj") else module.fc1.weight, std=fc_std) nn.init.normal_(module.fc1.proj.weight if getattr(module.fc1, "proj") else module.fc1.weight, std=fc_std)
nn.init.normal_(module.fc2.weight, std=in_proj_std) nn.init.normal_(module.fc2.weight, std=in_proj_std)
elif isinstance(module, ClvpEncoder): elif isinstance(module, ClvpEncoder):
config = self.config.text_config if hasattr(self.config, "text_config") else self.config config = self.config.get_text_config()
factor = config.initializer_factor factor = config.initializer_factor
module.projection.weight.data.normal_(mean=0.0, std=factor * (config.hidden_size**-0.5)) module.projection.weight.data.normal_(mean=0.0, std=factor * (config.hidden_size**-0.5))
elif isinstance(module, ClvpConditioningEncoder): elif isinstance(module, ClvpConditioningEncoder):

View File

@ -1330,7 +1330,7 @@ class OlmoeForCausalLM(OlmoePreTrainedModel):
cache_position=None, cache_position=None,
position_ids=None, position_ids=None,
use_cache=True, use_cache=True,
num_logits_to_keep=0, num_logits_to_keep=None,
**kwargs, **kwargs,
): ):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@ -1381,6 +1381,9 @@ class OlmoeForCausalLM(OlmoePreTrainedModel):
batch_size=batch_size, batch_size=batch_size,
) )
if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep
model_inputs.update( model_inputs.update(
{ {
"position_ids": position_ids, "position_ids": position_ids,
@ -1388,7 +1391,6 @@ class OlmoeForCausalLM(OlmoePreTrainedModel):
"past_key_values": past_key_values, "past_key_values": past_key_values,
"use_cache": use_cache, "use_cache": use_cache,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
} }
) )
return model_inputs return model_inputs

View File

@ -831,7 +831,7 @@ class GenerationTesterMixin:
# Sample constraints # Sample constraints
min_id = 3 min_id = 3
max_id = config.vocab_size max_id = config.get_text_config(decoder=True).vocab_size
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0] force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [ constraints = [
@ -889,7 +889,7 @@ class GenerationTesterMixin:
# Sample constraints # Sample constraints
min_id = 3 min_id = 3
max_id = model.config.vocab_size max_id = model.config.get_text_config(decoder=True).vocab_size
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0] force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [ constraints = [
PhrasalConstraint(force_tokens), PhrasalConstraint(force_tokens),
@ -2012,18 +2012,20 @@ class GenerationTesterMixin:
self.assertTrue(output.past_key_values is None) self.assertTrue(output.past_key_values is None)
def _check_scores(self, batch_size, scores, length, config): def _check_scores(self, batch_size, scores, length, config):
expected_shape = (batch_size, config.vocab_size) vocab_size = config.get_text_config(decoder=True).vocab_size
expected_shape = (batch_size, vocab_size)
self.assertIsInstance(scores, tuple) self.assertIsInstance(scores, tuple)
self.assertEqual(len(scores), length) self.assertEqual(len(scores), length)
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores)) self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
def _check_logits(self, batch_size, scores, config): def _check_logits(self, batch_size, scores, config):
vocab_size = config.get_text_config(decoder=True).vocab_size
self.assertIsInstance(scores, tuple) self.assertIsInstance(scores, tuple)
self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores)) self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores))
# vocabulary difference equal to one (imagegptmodel?) or zero (all other models) # vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
vocab_diff = config.vocab_size - scores[0].shape[-1] vocab_diff = vocab_size - scores[0].shape[-1]
self.assertTrue(vocab_diff in [0, 1]) self.assertTrue(vocab_diff in [0, 1])
self.assertListEqual([config.vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores)) self.assertListEqual([vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))
def _check_attentions_for_generate( def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1

View File

@ -1747,12 +1747,13 @@ class ModelTesterMixin:
self.assertTrue(models_equal) self.assertTrue(models_equal)
def test_resize_tokens_embeddings(self): def test_resize_tokens_embeddings(self):
if not self.test_resize_embeddings:
self.skipTest(reason="test_resize_embeddings is set to `False`")
( (
original_config, original_config,
inputs_dict, inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common() ) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings:
self.skipTest(reason="test_resize_embeddings is set to `False`")
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
config = copy.deepcopy(original_config) config = copy.deepcopy(original_config)
@ -1764,18 +1765,15 @@ class ModelTesterMixin:
if self.model_tester.is_training is False: if self.model_tester.is_training is False:
model.eval() model.eval()
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size model_vocab_size = config.get_text_config().vocab_size
# Retrieve the embeddings and clone theme # Retrieve the embeddings and clone theme
model_embed = model.resize_token_embeddings(model_vocab_size) model_embed = model.resize_token_embeddings(model_vocab_size)
cloned_embeddings = model_embed.weight.clone() cloned_embeddings = model_embed.weight.clone()
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size + 10) model_embed = model.resize_token_embeddings(model_vocab_size + 10)
new_model_vocab_size = ( new_model_vocab_size = model.config.get_text_config().vocab_size
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
self.assertEqual(new_model_vocab_size, model_vocab_size + 10) self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
# Check that it actually resizes the embeddings matrix # Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10) self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
@ -1787,11 +1785,7 @@ class ModelTesterMixin:
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size - 15) model_embed = model.resize_token_embeddings(model_vocab_size - 15)
new_model_vocab_size = ( new_model_vocab_size = model.config.get_text_config().vocab_size
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
self.assertEqual(new_model_vocab_size, model_vocab_size - 15) self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
# Check that it actually resizes the embeddings matrix # Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15) self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
@ -1817,21 +1811,13 @@ class ModelTesterMixin:
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size model_vocab_size = config.get_text_config().vocab_size
model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1) model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
new_model_vocab_size = ( new_model_vocab_size = model.config.get_text_config().vocab_size
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
self.assertTrue(new_model_vocab_size + 10, model_vocab_size) self.assertTrue(new_model_vocab_size + 10, model_vocab_size)
model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64) model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
new_model_vocab_size = ( new_model_vocab_size = model.config.get_text_config().vocab_size
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
self.assertTrue(model_embed.weight.shape[0] // 64, 0) self.assertTrue(model_embed.weight.shape[0] // 64, 0)
self.assertTrue(model_embed.weight.shape[0], new_model_vocab_size) self.assertTrue(model_embed.weight.shape[0], new_model_vocab_size)
@ -1852,13 +1838,10 @@ class ModelTesterMixin:
model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3) model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)
def test_resize_embeddings_untied(self): def test_resize_embeddings_untied(self):
(
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings: if not self.test_resize_embeddings:
self.skipTest(reason="test_resize_embeddings is set to `False`") self.skipTest(reason="test_resize_embeddings is set to `False`")
original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
original_config.tie_word_embeddings = False original_config.tie_word_embeddings = False
# if model cannot untied embeddings -> leave test # if model cannot untied embeddings -> leave test
@ -1874,13 +1857,9 @@ class ModelTesterMixin:
continue continue
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size model_vocab_size = config.get_text_config().vocab_size
model.resize_token_embeddings(model_vocab_size + 10) model.resize_token_embeddings(model_vocab_size + 10)
new_model_vocab_size = ( new_model_vocab_size = model.config.get_text_config().vocab_size
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
self.assertEqual(new_model_vocab_size, model_vocab_size + 10) self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
output_embeds = model.get_output_embeddings() output_embeds = model.get_output_embeddings()
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10) self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
@ -1892,11 +1871,7 @@ class ModelTesterMixin:
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
model.resize_token_embeddings(model_vocab_size - 15) model.resize_token_embeddings(model_vocab_size - 15)
new_model_vocab_size = ( new_model_vocab_size = model.config.get_text_config().vocab_size
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
self.assertEqual(new_model_vocab_size, model_vocab_size - 15) self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
# Check that it actually resizes the embeddings matrix # Check that it actually resizes the embeddings matrix
output_embeds = model.get_output_embeddings() output_embeds = model.get_output_embeddings()
@ -1988,7 +1963,7 @@ class ModelTesterMixin:
# self.assertTrue(check_same_values(embeddings, decoding)) # self.assertTrue(check_same_values(embeddings, decoding))
# Check that after resize they remain tied. # Check that after resize they remain tied.
vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size vocab_size = config.get_text_config().vocab_size
model_tied.resize_token_embeddings(vocab_size + 10) model_tied.resize_token_embeddings(vocab_size + 10)
params_tied_2 = list(model_tied.parameters()) params_tied_2 = list(model_tied.parameters())
self.assertEqual(len(params_tied_2), len(params_tied)) self.assertEqual(len(params_tied_2), len(params_tied))
@ -4831,7 +4806,7 @@ class ModelTesterMixin:
config, inputs = self.model_tester.prepare_config_and_inputs_for_common() config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
batch_size, sequence_length = inputs["input_ids"].shape batch_size, sequence_length = inputs["input_ids"].shape
vocab_size = config.vocab_size vocab_size = config.get_text_config().vocab_size
model = model_class(config).to(device=torch_device).eval() model = model_class(config).to(device=torch_device).eval()
# num_logits_to_keep=0 is a special case meaning "keep all logits" # num_logits_to_keep=0 is a special case meaning "keep all logits"

View File

@ -675,14 +675,12 @@ def validate_test_components(test_case, task, model, tokenizer, processor):
# Avoid `IndexError` in embedding layers # Avoid `IndexError` in embedding layers
CONFIG_WITHOUT_VOCAB_SIZE = ["CanineConfig"] CONFIG_WITHOUT_VOCAB_SIZE = ["CanineConfig"]
if tokenizer is not None: if tokenizer is not None:
config_vocab_size = getattr(model.config, "vocab_size", None) # Removing `decoder=True` in `get_text_config` can lead to conflicting values e.g. in MusicGen
config_vocab_size = getattr(model.config.get_text_config(decoder=True), "vocab_size", None)
# For CLIP-like models # For CLIP-like models
if config_vocab_size is None: if config_vocab_size is None:
if hasattr(model.config, "text_config"): if hasattr(model.config, "text_encoder"):
config_vocab_size = getattr(model.config.text_config, "vocab_size", None) config_vocab_size = getattr(model.config.text_config, "vocab_size", None)
elif hasattr(model.config, "text_encoder"):
config_vocab_size = getattr(model.config.text_encoder, "vocab_size", None)
if config_vocab_size is None and model.config.__class__.__name__ not in CONFIG_WITHOUT_VOCAB_SIZE: if config_vocab_size is None and model.config.__class__.__name__ not in CONFIG_WITHOUT_VOCAB_SIZE:
raise ValueError( raise ValueError(
"Could not determine `vocab_size` from model configuration while `tokenizer` is not `None`." "Could not determine `vocab_size` from model configuration while `tokenizer` is not `None`."