Config: unified logic to retrieve text config (#33219)

parent ebbe8d8014
commit d750b509fc
@@ -1019,17 +1019,17 @@ class PretrainedConfig(PushToHubMixin):
         """
         non_default_generation_parameters = {}
         decoder_attribute_name = None
-        default_config = None

         # Composite models don't have a default config, use their decoder config as a fallback for default values
         # If no known pattern is matched, then `default_config = None` -> check against the global generation defaults
         try:
             default_config = self.__class__()
         except ValueError:
-            for decoder_attribute_name in ("decoder", "generator", "text_config"):
-                if hasattr(self, decoder_attribute_name):
-                    default_config = getattr(self, decoder_attribute_name).__class__()
-                    break
+            decoder_config = self.get_text_config(decoder=True)
+            if decoder_config is not self:
+                default_config = decoder_config.__class__()
+            else:
+                decoder_config = None

         # If it is a composite model, we want to check the subconfig that will be used for generation
         self_decoder_config = self if decoder_attribute_name is None else getattr(self, decoder_attribute_name)
@@ -1057,6 +1057,36 @@ class PretrainedConfig(PushToHubMixin):

         return non_default_generation_parameters

+    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+        """
+        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
+        itself. On specific composite models, it is under a set of valid names.
+
+        If `decoder` is set to `True`, then only search for decoder config names.
+        """
+        decoder_possible_text_config_names = ("decoder", "generator", "text_config")
+        encoder_possible_text_config_names = ("text_encoder",)
+        if decoder:
+            possible_text_config_names = decoder_possible_text_config_names
+        else:
+            possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
+
+        valid_text_config_names = []
+        for text_config_name in possible_text_config_names:
+            if hasattr(self, text_config_name):
+                text_config = getattr(self, text_config_name, None)
+                if text_config is not None:
+                    valid_text_config_names += [text_config_name]
+
+        if len(valid_text_config_names) > 1:
+            raise ValueError(
+                f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
+                "case, using `get_text_config()` would be ambiguous. Please specify the desired text config directly."
+            )
+        elif len(valid_text_config_names) == 1:
+            return getattr(self, valid_text_config_names[0])
+        return self
+

 def get_configuration_file(configuration_files: List[str]) -> str:
     """
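A minimal usage sketch of the new helper (assuming a transformers version that already ships `get_text_config`; the composite example uses `LlavaConfig`, whose nested `text_config` is one of the recognized names):

from transformers import GPT2Config, LlavaConfig

# Text-only config: none of the composite attribute names exist, so the config itself is returned.
gpt2_config = GPT2Config()
assert gpt2_config.get_text_config() is gpt2_config

# Composite config (vision tower + language model): the nested text config is returned instead.
llava_config = LlavaConfig()
assert llava_config.get_text_config(decoder=True) is llava_config.text_config
print(llava_config.get_text_config().vocab_size)  # vocab size of the underlying language model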
@@ -1192,25 +1192,30 @@ class GenerationConfig(PushToHubMixin):
         """
         config_dict = model_config.to_dict()
         config_dict.pop("_from_model_config", None)
-        config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
+        generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)

         # Special case: some models have generation attributes set in the decoder. Use them if still unset in the
-        # generation config.
-        for decoder_name in ("decoder", "generator", "text_config"):
-            if decoder_name in config_dict:
-                default_generation_config = GenerationConfig()
-                decoder_config = config_dict[decoder_name]
-                for attr in config.to_dict().keys():
-                    if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
-                        setattr(config, attr, decoder_config[attr])
+        # generation config (which in turn is defined from the outer attributes of model config).
+        decoder_config = model_config.get_text_config(decoder=True)
+        if decoder_config is not model_config:
+            default_generation_config = GenerationConfig()
+            decoder_config_dict = decoder_config.to_dict()
+            for attr in generation_config.to_dict().keys():
+                is_unset = getattr(generation_config, attr) == getattr(default_generation_config, attr)
+                if attr in decoder_config_dict and is_unset:
+                    setattr(generation_config, attr, decoder_config_dict[attr])

         # If any `output_...` flag is set to `True`, we ensure `return_dict_in_generate` is set to `True`.
-        if config.return_dict_in_generate is False:
-            if any(getattr(config, extra_output_flag, False) for extra_output_flag in config.extra_output_flags):
-                config.return_dict_in_generate = True
+        if generation_config.return_dict_in_generate is False:
+            if any(
+                getattr(generation_config, extra_output_flag, False)
+                for extra_output_flag in generation_config.extra_output_flags
+            ):
+                generation_config.return_dict_in_generate = True

-        config._original_object_hash = hash(config)  # Hash to detect whether the instance was modified
-        return config
+        # Hash to detect whether the instance was modified
+        generation_config._original_object_hash = hash(generation_config)
+        return generation_config

     def update(self, **kwargs):
         """
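A sketch of the behaviour this enables, with hypothetical values: when a composite model config stores a generation attribute (here `max_length`) on its decoder sub-config, `from_model_config` copies it into the resulting `GenerationConfig` as long as the outer config left it at its default:

from transformers import EncoderDecoderConfig, GenerationConfig, GPT2Config

# Hypothetical composite config whose decoder sub-config carries a generation attribute.
decoder = GPT2Config(max_length=128)
composite = EncoderDecoderConfig.from_encoder_decoder_configs(GPT2Config(), decoder)

generation_config = GenerationConfig.from_model_config(composite)
print(generation_config.max_length)  # 128 -- taken from the decoder, since the outer config kept the default (20)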
@@ -209,10 +209,7 @@ def get_modules_to_fuse(model, quantization_config):
         current_fused_mapping = AWQ_FUSED_MAPPINGS[model.config.model_type]

         # Properly deal with the case where we have a multi-modal model as well (e.g. Llava)
-        if not hasattr(model.config, "text_config"):
-            config = model.config
-        else:
-            config = model.config.text_config
+        config = model.config.get_text_config(decoder=True)

         # Handle hidden_size, num_attention_heads, num_key_value_heads on our own.
         hidden_size = config.hidden_size
@@ -345,11 +342,8 @@ def _fuse_awq_mlp(model, current_module_name, fuse_module_names, module, target_
         previous_device = gate_proj.qweight.device

         # Deal also with the case model has `text_config` attribute
-        hidden_act = (
-            model.config.hidden_act
-            if not hasattr(model.config, "text_config")
-            else model.config.text_config.hidden_act
-        )
+        config = model.config.get_text_config(decoder=True)
+        hidden_act = config.hidden_act
         activation_fn = ACT2FN[hidden_act]
         new_module = target_cls(gate_proj, down_proj, up_proj, activation_fn)

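The same call works for any model the fusing code may receive; a short self-contained sketch (module fusing itself omitted) showing that the attention/MLP hyperparameters are read from the text (decoder) config whether the model is text-only or a multimodal wrapper such as Llava:

from transformers import GPT2Config, LlavaConfig

for cfg in (GPT2Config(), LlavaConfig()):
    text_cfg = cfg.get_text_config(decoder=True)
    # For GPT-2 this is the config itself; for Llava it is the nested language-model config.
    print(type(text_cfg).__name__, text_cfg.hidden_size, text_cfg.num_attention_heads)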
@@ -2025,11 +2025,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         else:
             vocab_size = model_embeds.weight.shape[0]

-        # Update base model and current model config
-        if hasattr(self.config, "text_config"):
-            self.config.text_config.vocab_size = vocab_size
-        else:
-            self.config.vocab_size = vocab_size
+        # Update base model and current model config.
+        self.config.get_text_config().vocab_size = vocab_size
         self.vocab_size = vocab_size

         # Tie weights again if needed
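Note that `get_text_config()` returns the live (sub-)config object rather than a copy, so assigning through it updates the nested config in place for composite models and the top-level config otherwise; a small sketch with an assumed new vocabulary size:

from transformers import LlavaConfig

config = LlavaConfig()
config.get_text_config().vocab_size = 32064  # hypothetical resized vocab; writes through to config.text_config
print(config.text_config.vocab_size)  # 32064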
@@ -735,7 +735,7 @@ class ClvpPreTrainedModel(PreTrainedModel):
             nn.init.normal_(module.fc1.proj.weight if getattr(module.fc1, "proj") else module.fc1.weight, std=fc_std)
             nn.init.normal_(module.fc2.weight, std=in_proj_std)
         elif isinstance(module, ClvpEncoder):
-            config = self.config.text_config if hasattr(self.config, "text_config") else self.config
+            config = self.config.get_text_config()
             factor = config.initializer_factor
             module.projection.weight.data.normal_(mean=0.0, std=factor * (config.hidden_size**-0.5))
         elif isinstance(module, ClvpConditioningEncoder):
@@ -1330,7 +1330,7 @@ class OlmoeForCausalLM(OlmoePreTrainedModel):
         cache_position=None,
         position_ids=None,
         use_cache=True,
-        num_logits_to_keep=0,
+        num_logits_to_keep=None,
         **kwargs,
     ):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@@ -1381,6 +1381,9 @@ class OlmoeForCausalLM(OlmoePreTrainedModel):
                 batch_size=batch_size,
             )

+        if num_logits_to_keep is not None:
+            model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
         model_inputs.update(
             {
                 "position_ids": position_ids,
@@ -1388,7 +1391,6 @@ class OlmoeForCausalLM(OlmoePreTrainedModel):
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "attention_mask": attention_mask,
-                "num_logits_to_keep": num_logits_to_keep,
             }
         )
         return model_inputs
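The `num_logits_to_keep` change above switches the default from `0` to `None` and only forwards the argument when a value was actually provided, so `0` keeps its meaning of "keep all logits" while `None` means the argument is not passed at all. A condensed, hypothetical stand-in for the pattern (not the real `prepare_inputs_for_generation`):

def build_model_inputs(input_ids, num_logits_to_keep=None):
    # Hypothetical, simplified sketch of the forwarding logic only.
    model_inputs = {"input_ids": input_ids}
    if num_logits_to_keep is not None:
        # 0 still means "keep all logits"; None means the key is omitted entirely.
        model_inputs["num_logits_to_keep"] = num_logits_to_keep
    return model_inputs

print(build_model_inputs([1, 2, 3]))                        # {'input_ids': [1, 2, 3]}
print(build_model_inputs([1, 2, 3], num_logits_to_keep=1))  # adds 'num_logits_to_keep': 1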
@@ -831,7 +831,7 @@ class GenerationTesterMixin:

         # Sample constraints
         min_id = 3
-        max_id = config.vocab_size
+        max_id = config.get_text_config(decoder=True).vocab_size

         force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
         constraints = [
@@ -889,7 +889,7 @@ class GenerationTesterMixin:

         # Sample constraints
         min_id = 3
-        max_id = model.config.vocab_size
+        max_id = model.config.get_text_config(decoder=True).vocab_size
         force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
         constraints = [
             PhrasalConstraint(force_tokens),
@@ -2012,18 +2012,20 @@ class GenerationTesterMixin:
         self.assertTrue(output.past_key_values is None)

     def _check_scores(self, batch_size, scores, length, config):
-        expected_shape = (batch_size, config.vocab_size)
+        vocab_size = config.get_text_config(decoder=True).vocab_size
+        expected_shape = (batch_size, vocab_size)
         self.assertIsInstance(scores, tuple)
         self.assertEqual(len(scores), length)
         self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))

     def _check_logits(self, batch_size, scores, config):
+        vocab_size = config.get_text_config(decoder=True).vocab_size
         self.assertIsInstance(scores, tuple)
         self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores))
         # vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
-        vocab_diff = config.vocab_size - scores[0].shape[-1]
+        vocab_diff = vocab_size - scores[0].shape[-1]
         self.assertTrue(vocab_diff in [0, 1])
-        self.assertListEqual([config.vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))
+        self.assertListEqual([vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))

     def _check_attentions_for_generate(
         self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
@@ -1747,12 +1747,13 @@ class ModelTesterMixin:
         self.assertTrue(models_equal)

     def test_resize_tokens_embeddings(self):
+        if not self.test_resize_embeddings:
+            self.skipTest(reason="test_resize_embeddings is set to `False`")
+
         (
             original_config,
             inputs_dict,
         ) = self.model_tester.prepare_config_and_inputs_for_common()
-        if not self.test_resize_embeddings:
-            self.skipTest(reason="test_resize_embeddings is set to `False`")

         for model_class in self.all_model_classes:
             config = copy.deepcopy(original_config)
@@ -1764,18 +1765,15 @@ class ModelTesterMixin:
             if self.model_tester.is_training is False:
                 model.eval()

-            model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            model_vocab_size = config.get_text_config().vocab_size
             # Retrieve the embeddings and clone theme
             model_embed = model.resize_token_embeddings(model_vocab_size)
             cloned_embeddings = model_embed.weight.clone()

             # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
             model_embed = model.resize_token_embeddings(model_vocab_size + 10)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
             # Check that it actually resizes the embeddings matrix
             self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
@@ -1787,11 +1785,7 @@ class ModelTesterMixin:

             # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
             model_embed = model.resize_token_embeddings(model_vocab_size - 15)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
             # Check that it actually resizes the embeddings matrix
             self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
@@ -1817,21 +1811,13 @@ class ModelTesterMixin:
             model = model_class(config)
             model.to(torch_device)

-            model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            model_vocab_size = config.get_text_config().vocab_size
             model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertTrue(new_model_vocab_size + 10, model_vocab_size)

             model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertTrue(model_embed.weight.shape[0] // 64, 0)

             self.assertTrue(model_embed.weight.shape[0], new_model_vocab_size)
@@ -1852,13 +1838,10 @@ class ModelTesterMixin:
                 model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)

     def test_resize_embeddings_untied(self):
-        (
-            original_config,
-            inputs_dict,
-        ) = self.model_tester.prepare_config_and_inputs_for_common()
         if not self.test_resize_embeddings:
             self.skipTest(reason="test_resize_embeddings is set to `False`")

+        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         original_config.tie_word_embeddings = False

         # if model cannot untied embeddings -> leave test
@@ -1874,13 +1857,9 @@ class ModelTesterMixin:
                 continue

             # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
-            model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            model_vocab_size = config.get_text_config().vocab_size
             model.resize_token_embeddings(model_vocab_size + 10)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
             output_embeds = model.get_output_embeddings()
             self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
@@ -1892,11 +1871,7 @@ class ModelTesterMixin:

             # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
             model.resize_token_embeddings(model_vocab_size - 15)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
             # Check that it actually resizes the embeddings matrix
             output_embeds = model.get_output_embeddings()
@@ -1988,7 +1963,7 @@ class ModelTesterMixin:
             # self.assertTrue(check_same_values(embeddings, decoding))

             # Check that after resize they remain tied.
-            vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            vocab_size = config.get_text_config().vocab_size
             model_tied.resize_token_embeddings(vocab_size + 10)
             params_tied_2 = list(model_tied.parameters())
             self.assertEqual(len(params_tied_2), len(params_tied))
@@ -4831,7 +4806,7 @@ class ModelTesterMixin:

         config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
         batch_size, sequence_length = inputs["input_ids"].shape
-        vocab_size = config.vocab_size
+        vocab_size = config.get_text_config().vocab_size
         model = model_class(config).to(device=torch_device).eval()

         # num_logits_to_keep=0 is a special case meaning "keep all logits"
@@ -675,14 +675,12 @@ def validate_test_components(test_case, task, model, tokenizer, processor):
     # Avoid `IndexError` in embedding layers
     CONFIG_WITHOUT_VOCAB_SIZE = ["CanineConfig"]
     if tokenizer is not None:
-        config_vocab_size = getattr(model.config, "vocab_size", None)
+        # Removing `decoder=True` in `get_text_config` can lead to conflicting values e.g. in MusicGen
+        config_vocab_size = getattr(model.config.get_text_config(decoder=True), "vocab_size", None)
         # For CLIP-like models
         if config_vocab_size is None:
-            if hasattr(model.config, "text_config"):
+            if hasattr(model.config, "text_encoder"):
                 config_vocab_size = getattr(model.config.text_config, "vocab_size", None)
-            elif hasattr(model.config, "text_encoder"):
-                config_vocab_size = getattr(model.config.text_encoder, "vocab_size", None)

         if config_vocab_size is None and model.config.__class__.__name__ not in CONFIG_WITHOUT_VOCAB_SIZE:
             raise ValueError(
                 "Could not determine `vocab_size` from model configuration while `tokenizer` is not `None`."
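Why `decoder=True` matters here: MusicGen's config exposes both a `text_encoder` and a `decoder` sub-config, so an unrestricted lookup would find two candidate text configs and raise, while restricting the search to decoder names stays unambiguous. A small sketch (fetches the config from the Hub):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("facebook/musicgen-small")  # has `text_encoder`, `audio_encoder`, `decoder`
print(type(config.get_text_config(decoder=True)).__name__)  # MusicgenDecoderConfig

try:
    config.get_text_config()  # searches encoder *and* decoder names -> two matches
except ValueError as err:
    print(err)  # "Multiple valid text configs were found ..."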