Config: unified logic to retrieve text config (#33219)
This commit is contained in:
parent ebbe8d8014
commit d750b509fc
@@ -67,4 +67,4 @@ def main():

 if __name__ == "__main__":
     main()
@@ -1019,17 +1019,17 @@ class PretrainedConfig(PushToHubMixin):
         """
         non_default_generation_parameters = {}
         decoder_attribute_name = None
         default_config = None

         # Composite models don't have a default config, use their decoder config as a fallback for default values
         # If no known pattern is matched, then `default_config = None` -> check against the global generation defaults
         try:
             default_config = self.__class__()
         except ValueError:
-            for decoder_attribute_name in ("decoder", "generator", "text_config"):
-                if hasattr(self, decoder_attribute_name):
-                    default_config = getattr(self, decoder_attribute_name).__class__()
-                    break
+            decoder_config = self.get_text_config(decoder=True)
+            if decoder_config is not self:
+                default_config = decoder_config.__class__()
+            else:
+                decoder_config = None

         # If it is a composite model, we want to check the subconfig that will be used for generation
         self_decoder_config = self if decoder_attribute_name is None else getattr(self, decoder_attribute_name)
@@ -1057,6 +1057,36 @@ class PretrainedConfig(PushToHubMixin):
         return non_default_generation_parameters

+    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+        """
+        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
+        itself. On specific composite models, it is under a set of valid names.
+
+        If `decoder` is set to `True`, then only search for decoder config names.
+        """
+        decoder_possible_text_config_names = ("decoder", "generator", "text_config")
+        encoder_possible_text_config_names = ("text_encoder",)
+        if decoder:
+            possible_text_config_names = decoder_possible_text_config_names
+        else:
+            possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
+
+        valid_text_config_names = []
+        for text_config_name in possible_text_config_names:
+            if hasattr(self, text_config_name):
+                text_config = getattr(self, text_config_name, None)
+                if text_config is not None:
+                    valid_text_config_names += [text_config_name]
+
+        if len(valid_text_config_names) > 1:
+            raise ValueError(
+                f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
+                "case, using `get_text_config()` would be ambiguous. Please specify the desired text config directly."
+            )
+        elif len(valid_text_config_names) == 1:
+            return getattr(self, valid_text_config_names[0])
+        return self
+

 def get_configuration_file(configuration_files: List[str]) -> str:
     """
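For reference, a minimal sketch of how the new helper resolves on a plain versus a composite config. GPT2Config and LlavaConfig are used purely as illustrative examples and are not part of this diff:

from transformers import GPT2Config, LlavaConfig

# Plain text model: none of the candidate attribute names exist, so the config itself is returned.
gpt2_config = GPT2Config()
assert gpt2_config.get_text_config() is gpt2_config

# Composite multimodal model: the nested `text_config` is found and returned instead.
llava_config = LlavaConfig()
assert llava_config.get_text_config(decoder=True) is llava_config.text_config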
@@ -1192,25 +1192,30 @@ class GenerationConfig(PushToHubMixin):
         """
         config_dict = model_config.to_dict()
         config_dict.pop("_from_model_config", None)
-        config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
+        generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)

         # Special case: some models have generation attributes set in the decoder. Use them if still unset in the
-        # generation config.
-        for decoder_name in ("decoder", "generator", "text_config"):
-            if decoder_name in config_dict:
-                default_generation_config = GenerationConfig()
-                decoder_config = config_dict[decoder_name]
-                for attr in config.to_dict().keys():
-                    if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
-                        setattr(config, attr, decoder_config[attr])
+        # generation config (which in turn is defined from the outer attributes of model config).
+        decoder_config = model_config.get_text_config(decoder=True)
+        if decoder_config is not model_config:
+            default_generation_config = GenerationConfig()
+            decoder_config_dict = decoder_config.to_dict()
+            for attr in generation_config.to_dict().keys():
+                is_unset = getattr(generation_config, attr) == getattr(default_generation_config, attr)
+                if attr in decoder_config_dict and is_unset:
+                    setattr(generation_config, attr, decoder_config_dict[attr])

         # If any `output_...` flag is set to `True`, we ensure `return_dict_in_generate` is set to `True`.
-        if config.return_dict_in_generate is False:
-            if any(getattr(config, extra_output_flag, False) for extra_output_flag in config.extra_output_flags):
-                config.return_dict_in_generate = True
+        if generation_config.return_dict_in_generate is False:
+            if any(
+                getattr(generation_config, extra_output_flag, False)
+                for extra_output_flag in generation_config.extra_output_flags
+            ):
+                generation_config.return_dict_in_generate = True

-        config._original_object_hash = hash(config)  # Hash to detect whether the instance was modified
-        return config
+        # Hash to detect whether the instance was modified
+        generation_config._original_object_hash = hash(generation_config)
+        return generation_config

     def update(self, **kwargs):
         """
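A hedged sketch of what the rewritten branch enables: generation attributes that live only on the decoder sub-config of a composite model are now inherited by the derived GenerationConfig, provided the corresponding attribute was left at its default on the outer config. The attribute placement below is hypothetical and only for illustration:

from transformers import GenerationConfig, LlavaConfig

llava_config = LlavaConfig()
llava_config.text_config.max_length = 64  # hypothetical: set only on the decoder sub-config

generation_config = GenerationConfig.from_model_config(llava_config)
# `max_length` was left at its default (20) on the outer config, so the decoder value is picked up.
print(generation_config.max_length)  # 64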
@@ -209,10 +209,7 @@ def get_modules_to_fuse(model, quantization_config):
     current_fused_mapping = AWQ_FUSED_MAPPINGS[model.config.model_type]

     # Properly deal with the case where we have a multi-modal model as well (e.g. Llava)
-    if not hasattr(model.config, "text_config"):
-        config = model.config
-    else:
-        config = model.config.text_config
+    config = model.config.get_text_config(decoder=True)

     # Handle hidden_size, num_attention_heads, num_key_value_heads on our own.
     hidden_size = config.hidden_size
@@ -345,11 +342,8 @@ def _fuse_awq_mlp(model, current_module_name, fuse_module_names, module, target_cls):
     previous_device = gate_proj.qweight.device

     # Deal also with the case model has `text_config` attribute
-    hidden_act = (
-        model.config.hidden_act
-        if not hasattr(model.config, "text_config")
-        else model.config.text_config.hidden_act
-    )
+    config = model.config.get_text_config(decoder=True)
+    hidden_act = config.hidden_act
     activation_fn = ACT2FN[hidden_act]
     new_module = target_cls(gate_proj, down_proj, up_proj, activation_fn)
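Both AWQ changes rely on get_text_config(decoder=True) being equivalent to the hasattr-based branching it replaces. A small sanity sketch, with LlavaConfig and MistralConfig used only as stand-ins for a multimodal and a text-only model:

from transformers import LlavaConfig, MistralConfig

def old_style_text_config(config):
    # Mirrors the branching removed above, kept here only for comparison.
    return config.text_config if hasattr(config, "text_config") else config

for config in (LlavaConfig(), MistralConfig()):
    assert config.get_text_config(decoder=True) is old_style_text_config(config)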
@@ -2025,11 +2025,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         else:
             vocab_size = model_embeds.weight.shape[0]

-        # Update base model and current model config
-        if hasattr(self.config, "text_config"):
-            self.config.text_config.vocab_size = vocab_size
-        else:
-            self.config.vocab_size = vocab_size
+        # Update base model and current model config.
+        self.config.get_text_config().vocab_size = vocab_size
         self.vocab_size = vocab_size

         # Tie weights again if needed
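A minimal sketch of why a single assignment now covers both cases: get_text_config() returns the nested text config on composite models and the config itself on plain models, so writing through it updates the right object either way. LlavaConfig and MistralConfig below are illustrative, not part of this diff:

from transformers import LlavaConfig, MistralConfig

composite = LlavaConfig()
composite.get_text_config().vocab_size = 32064  # updates the nested text config in place
assert composite.text_config.vocab_size == 32064

plain = MistralConfig()
plain.get_text_config().vocab_size = 32064  # no sub-config, so the outer config is updated
assert plain.vocab_size == 32064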
@@ -735,7 +735,7 @@ class ClvpPreTrainedModel(PreTrainedModel):
             nn.init.normal_(module.fc1.proj.weight if getattr(module.fc1, "proj") else module.fc1.weight, std=fc_std)
             nn.init.normal_(module.fc2.weight, std=in_proj_std)
         elif isinstance(module, ClvpEncoder):
-            config = self.config.text_config if hasattr(self.config, "text_config") else self.config
+            config = self.config.get_text_config()
            factor = config.initializer_factor
            module.projection.weight.data.normal_(mean=0.0, std=factor * (config.hidden_size**-0.5))
        elif isinstance(module, ClvpConditioningEncoder):
@@ -1330,7 +1330,7 @@ class OlmoeForCausalLM(OlmoePreTrainedModel):
         cache_position=None,
         position_ids=None,
         use_cache=True,
-        num_logits_to_keep=0,
+        num_logits_to_keep=None,
         **kwargs,
     ):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@@ -1381,6 +1381,9 @@ class OlmoeForCausalLM(OlmoePreTrainedModel):
                 batch_size=batch_size,
             )

+        if num_logits_to_keep is not None:
+            model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
         model_inputs.update(
             {
                 "position_ids": position_ids,
@@ -1388,7 +1391,6 @@ class OlmoeForCausalLM(OlmoePreTrainedModel):
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "attention_mask": attention_mask,
-                "num_logits_to_keep": num_logits_to_keep,
             }
         )
         return model_inputs
@@ -831,7 +831,7 @@ class GenerationTesterMixin:

             # Sample constraints
             min_id = 3
-            max_id = config.vocab_size
+            max_id = config.get_text_config(decoder=True).vocab_size

             force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
             constraints = [
@@ -889,7 +889,7 @@ class GenerationTesterMixin:

             # Sample constraints
             min_id = 3
-            max_id = model.config.vocab_size
+            max_id = model.config.get_text_config(decoder=True).vocab_size
             force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
             constraints = [
                 PhrasalConstraint(force_tokens),
@@ -2012,18 +2012,20 @@ class GenerationTesterMixin:
         self.assertTrue(output.past_key_values is None)

     def _check_scores(self, batch_size, scores, length, config):
-        expected_shape = (batch_size, config.vocab_size)
+        vocab_size = config.get_text_config(decoder=True).vocab_size
+        expected_shape = (batch_size, vocab_size)
         self.assertIsInstance(scores, tuple)
         self.assertEqual(len(scores), length)
         self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))

     def _check_logits(self, batch_size, scores, config):
+        vocab_size = config.get_text_config(decoder=True).vocab_size
         self.assertIsInstance(scores, tuple)
         self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores))
         # vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
-        vocab_diff = config.vocab_size - scores[0].shape[-1]
+        vocab_diff = vocab_size - scores[0].shape[-1]
         self.assertTrue(vocab_diff in [0, 1])
-        self.assertListEqual([config.vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))
+        self.assertListEqual([vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))

     def _check_attentions_for_generate(
         self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
@@ -1747,12 +1747,13 @@ class ModelTesterMixin:
         self.assertTrue(models_equal)

     def test_resize_tokens_embeddings(self):
+        if not self.test_resize_embeddings:
+            self.skipTest(reason="test_resize_embeddings is set to `False`")
+
         (
             original_config,
             inputs_dict,
         ) = self.model_tester.prepare_config_and_inputs_for_common()
-        if not self.test_resize_embeddings:
-            self.skipTest(reason="test_resize_embeddings is set to `False`")

         for model_class in self.all_model_classes:
             config = copy.deepcopy(original_config)
@@ -1764,18 +1765,15 @@ class ModelTesterMixin:
             if self.model_tester.is_training is False:
                 model.eval()

-            model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            model_vocab_size = config.get_text_config().vocab_size
             # Retrieve the embeddings and clone them
             model_embed = model.resize_token_embeddings(model_vocab_size)
             cloned_embeddings = model_embed.weight.clone()

             # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
             model_embed = model.resize_token_embeddings(model_vocab_size + 10)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size

             self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
             # Check that it actually resizes the embeddings matrix
             self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
@@ -1787,11 +1785,7 @@ class ModelTesterMixin:

             # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
             model_embed = model.resize_token_embeddings(model_vocab_size - 15)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
             # Check that it actually resizes the embeddings matrix
             self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
@@ -1817,21 +1811,13 @@ class ModelTesterMixin:
             model = model_class(config)
             model.to(torch_device)

-            model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            model_vocab_size = config.get_text_config().vocab_size
             model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertTrue(new_model_vocab_size + 10, model_vocab_size)

             model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertTrue(model_embed.weight.shape[0] // 64, 0)

             self.assertTrue(model_embed.weight.shape[0], new_model_vocab_size)
@@ -1852,13 +1838,10 @@ class ModelTesterMixin:
                 model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)

     def test_resize_embeddings_untied(self):
-        (
-            original_config,
-            inputs_dict,
-        ) = self.model_tester.prepare_config_and_inputs_for_common()
         if not self.test_resize_embeddings:
             self.skipTest(reason="test_resize_embeddings is set to `False`")

+        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         original_config.tie_word_embeddings = False

         # if model cannot untie embeddings -> leave test
@@ -1874,13 +1857,9 @@ class ModelTesterMixin:
                continue

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
-            model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            model_vocab_size = config.get_text_config().vocab_size
            model.resize_token_embeddings(model_vocab_size + 10)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
            self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
            output_embeds = model.get_output_embeddings()
            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
@@ -1892,11 +1871,7 @@ class ModelTesterMixin:

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model.resize_token_embeddings(model_vocab_size - 15)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
            self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            output_embeds = model.get_output_embeddings()
@@ -1988,7 +1963,7 @@ class ModelTesterMixin:
            # self.assertTrue(check_same_values(embeddings, decoding))

            # Check that after resize they remain tied.
-            vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            vocab_size = config.get_text_config().vocab_size
            model_tied.resize_token_embeddings(vocab_size + 10)
            params_tied_2 = list(model_tied.parameters())
            self.assertEqual(len(params_tied_2), len(params_tied))
@@ -4831,7 +4806,7 @@ class ModelTesterMixin:

            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
            batch_size, sequence_length = inputs["input_ids"].shape
-            vocab_size = config.vocab_size
+            vocab_size = config.get_text_config().vocab_size
            model = model_class(config).to(device=torch_device).eval()

            # num_logits_to_keep=0 is a special case meaning "keep all logits"
@@ -675,14 +675,12 @@ def validate_test_components(test_case, task, model, tokenizer, processor):
     # Avoid `IndexError` in embedding layers
     CONFIG_WITHOUT_VOCAB_SIZE = ["CanineConfig"]
     if tokenizer is not None:
-        config_vocab_size = getattr(model.config, "vocab_size", None)
+        # Removing `decoder=True` in `get_text_config` can lead to conflicting values e.g. in MusicGen
+        config_vocab_size = getattr(model.config.get_text_config(decoder=True), "vocab_size", None)
         # For CLIP-like models
         if config_vocab_size is None:
-            if hasattr(model.config, "text_config"):
-                config_vocab_size = getattr(model.config.text_config, "vocab_size", None)
-            elif hasattr(model.config, "text_encoder"):
+            if hasattr(model.config, "text_encoder"):
                 config_vocab_size = getattr(model.config.text_encoder, "vocab_size", None)

         if config_vocab_size is None and model.config.__class__.__name__ not in CONFIG_WITHOUT_VOCAB_SIZE:
             raise ValueError(
                 "Could not determine `vocab_size` from model configuration while `tokenizer` is not `None`."
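The comment about MusicGen refers to configs that expose both an encoder-side and a decoder-side text config; for those, only the decoder=True lookup is unambiguous. A sketch, assuming MusicgenConfig.from_sub_models_config accepts the three sub-configs shown (the keyword names below are an assumption, not confirmed by this diff):

from transformers import EncodecConfig, MusicgenConfig, MusicgenDecoderConfig, T5Config

# Assumed constructor: builds a composite config with `text_encoder`, `audio_encoder` and `decoder` sub-configs.
config = MusicgenConfig.from_sub_models_config(
    text_encoder_config=T5Config(),
    audio_encoder_config=EncodecConfig(),
    decoder_config=MusicgenDecoderConfig(),
)

# Unambiguous: only decoder-side names ("decoder", "generator", "text_config") are searched.
assert config.get_text_config(decoder=True) is config.decoder

# Ambiguous: both `text_encoder` and `decoder` qualify, so this raises a ValueError.
try:
    config.get_text_config()
except ValueError as exc:
    print(exc)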