Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-16 19:18:24 +06:00)

Add strong test for configuration attributes (#14000)

* Add strong test for configuration attributes
* Add fake modif to trigger all tests
* Add a better fake modif
* Ignore is_encoder_decoder
* Fix faulty configs
* Remove fake modif
Parent: 0ef61d392c
Commit: f2002fea11
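For orientation (not part of this commit), a minimal sketch of how a model test file typically drives ConfigTester; the relative import path and the hidden_size value are illustrative assumptions. Once this change lands, run_common_tests() also runs the new check_config_arguments_init check.

import unittest

from transformers import BertConfig

from .test_configuration_common import ConfigTester  # import path as used inside the tests package (assumption)


class BertConfigTest(unittest.TestCase):
    def test_config(self):
        # run_common_tests() now also calls check_config_arguments_init, which builds
        # BertConfig(**config_common_kwargs) and verifies every kwarg is stored.
        ConfigTester(self, config_class=BertConfig, hidden_size=37).run_common_tests()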
src/transformers/models/fsmt/configuration_fsmt.py:

@@ -166,7 +166,6 @@ class FSMTConfig(PretrainedConfig):
         self.src_vocab_size = src_vocab_size
         self.tgt_vocab_size = tgt_vocab_size
         self.d_model = d_model  # encoder_embed_dim and decoder_embed_dim
-        self.max_length = max_length
 
         self.encoder_ffn_dim = encoder_ffn_dim
         self.encoder_layers = self.num_hidden_layers = encoder_layers
@@ -180,10 +179,6 @@ class FSMTConfig(PretrainedConfig):
         self.init_std = init_std  # Normal(0, this parameter)
         self.activation_function = activation_function
 
-        self.num_beams = num_beams
-        self.length_penalty = length_penalty
-        self.early_stopping = early_stopping
-
         self.decoder = DecoderConfig(vocab_size=tgt_vocab_size, bos_token_id=eos_token_id)
         if "decoder" in common_kwargs:
             del common_kwargs["decoder"]
@@ -204,6 +199,10 @@ class FSMTConfig(PretrainedConfig):
             is_encoder_decoder=is_encoder_decoder,
             tie_word_embeddings=tie_word_embeddings,
             forced_eos_token_id=forced_eos_token_id,
+            max_length=max_length,
+            num_beams=num_beams,
+            length_penalty=length_penalty,
+            early_stopping=early_stopping,
             **common_kwargs,
         )
 
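The effect of the FSMTConfig change: the generation defaults are no longer assigned in the subclass but forwarded to PretrainedConfig.__init__, which owns them like every other common attribute. A hedged sketch of the (unchanged) user-facing behavior, with arbitrary values:

from transformers import FSMTConfig

# max_length, num_beams, length_penalty and early_stopping still end up on the
# config; they are now set by the PretrainedConfig base class rather than by
# FSMTConfig itself.
config = FSMTConfig(max_length=60, num_beams=5, length_penalty=1.1, early_stopping=True)
assert config.max_length == 60 and config.num_beams == 5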
src/transformers/models/lxmert/configuration_lxmert.py:

@@ -131,7 +131,6 @@ class LxmertConfig(PretrainedConfig):
         type_vocab_size=2,
         initializer_range=0.02,
         layer_norm_eps=1e-12,
-        pad_token_id=0,
         l_layers=9,
         x_layers=5,
         r_layers=5,
@@ -145,8 +144,6 @@ class LxmertConfig(PretrainedConfig):
         visual_obj_loss=True,
         visual_attr_loss=True,
         visual_feat_loss=True,
-        output_attentions=False,
-        output_hidden_states=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -176,7 +173,5 @@ class LxmertConfig(PretrainedConfig):
         self.visual_obj_loss = visual_obj_loss
         self.visual_attr_loss = visual_attr_loss
         self.visual_feat_loss = visual_feat_loss
-        self.output_hidden_states = output_hidden_states
-        self.output_attentions = output_attentions
         self.num_hidden_layers = {"vision": r_layers, "cross_encoder": x_layers, "language": l_layers}
         super().__init__(**kwargs)
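With the explicit output_attentions / output_hidden_states parameters removed, those flags now travel through **kwargs into PretrainedConfig.__init__ instead of being captured by the signature and overwritten when super().__init__(**kwargs) runs last. A hedged sketch of the intended post-fix behavior:

from transformers import LxmertConfig

# These common flags now reach the base class through **kwargs, so the values the
# caller passes are the ones that stick on the config.
config = LxmertConfig(output_attentions=True, output_hidden_states=True)
assert config.output_attentions is True and config.output_hidden_states is True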
tests/test_configuration_common.py:

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import copy
 import json
 import os
 import tempfile
@@ -21,10 +21,62 @@ import unittest
 
 from huggingface_hub import HfApi
 from requests.exceptions import HTTPError
-from transformers import BertConfig, GPT2Config
+from transformers import BertConfig, GPT2Config, is_torch_available
+from transformers.configuration_utils import PretrainedConfig
 from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test
 
 
+config_common_kwargs = {
+    "return_dict": False,
+    "output_hidden_states": True,
+    "output_attentions": True,
+    "torchscript": True,
+    "torch_dtype": "float16",
+    "use_bfloat16": True,
+    "pruned_heads": {"a": 1},
+    "tie_word_embeddings": False,
+    "is_decoder": True,
+    "cross_attention_hidden_size": 128,
+    "add_cross_attention": True,
+    "tie_encoder_decoder": True,
+    "max_length": 50,
+    "min_length": 3,
+    "do_sample": True,
+    "early_stopping": True,
+    "num_beams": 3,
+    "num_beam_groups": 3,
+    "diversity_penalty": 0.5,
+    "temperature": 2.0,
+    "top_k": 10,
+    "top_p": 0.7,
+    "repetition_penalty": 0.8,
+    "length_penalty": 0.8,
+    "no_repeat_ngram_size": 5,
+    "encoder_no_repeat_ngram_size": 5,
+    "bad_words_ids": [1, 2, 3],
+    "num_return_sequences": 3,
+    "chunk_size_feed_forward": 5,
+    "output_scores": True,
+    "return_dict_in_generate": True,
+    "forced_bos_token_id": 2,
+    "forced_eos_token_id": 3,
+    "remove_invalid_values": True,
+    "architectures": ["BertModel"],
+    "finetuning_task": "translation",
+    "id2label": {0: "label"},
+    "label2id": {"label": "0"},
+    "tokenizer_class": "BertTokenizerFast",
+    "prefix": "prefix",
+    "bos_token_id": 6,
+    "pad_token_id": 7,
+    "eos_token_id": 8,
+    "sep_token_id": 9,
+    "decoder_start_token_id": 10,
+    "task_specific_params": {"translation": "some_params"},
+    "problem_type": "regression",
+}
+
+
 class ConfigTester(object):
     def __init__(self, parent, config_class=None, has_text_modality=True, **kwargs):
         self.parent = parent
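Every entry in config_common_kwargs deliberately differs from the corresponding PretrainedConfig default (the second new test below enforces this), so the init check can detect a kwarg that gets silently dropped and falls back to its default. A small hedged illustration:

from transformers import PretrainedConfig

# The library default for return_dict is True, so the test dict uses False:
# if a config class swallowed the kwarg, the comparison would catch the mismatch.
assert PretrainedConfig().return_dict is True
assert PretrainedConfig(return_dict=False).return_dict is False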
@@ -108,6 +160,26 @@ class ConfigTester(object):
         config = self.config_class()
         self.parent.assertIsNotNone(config)
 
+    def check_config_arguments_init(self):
+        kwargs = copy.deepcopy(config_common_kwargs)
+        config = self.config_class(**kwargs)
+        wrong_values = []
+        for key, value in config_common_kwargs.items():
+            if key == "torch_dtype":
+                if not is_torch_available():
+                    continue
+                else:
+                    import torch
+
+                    if config.torch_dtype != torch.float16:
+                        wrong_values.append(("torch_dtype", config.torch_dtype, torch.float16))
+            elif getattr(config, key) != value:
+                wrong_values.append((key, getattr(config, key), value))
+
+        if len(wrong_values) > 0:
+            errors = "\n".join([f"- {v[0]}: got {v[1]} instead of {v[2]}" for v in wrong_values])
+            raise ValueError(f"The following keys were not properly set in the config:\n{errors}")
+
     def run_common_tests(self):
         self.create_and_test_config_common_properties()
         self.create_and_test_config_to_json_string()
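As a hedged illustration of what check_config_arguments_init catches (hypothetical config, not from this diff): a subclass that accepts a common argument explicitly but never stores or forwards it now fails run_common_tests.

from transformers import PretrainedConfig


class BrokenConfig(PretrainedConfig):
    # Hypothetical faulty pattern: max_length is swallowed by the signature and
    # never forwarded, so PretrainedConfig.__init__ falls back to its default (20).
    def __init__(self, max_length=20, **kwargs):
        super().__init__(**kwargs)


config = BrokenConfig(max_length=50)
print(config.max_length)  # 20, not 50 -- exactly the mismatch the new check reports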
@@ -115,6 +187,7 @@ class ConfigTester(object):
         self.create_and_test_config_from_and_save_pretrained()
         self.create_and_test_config_with_num_labels()
         self.check_config_can_be_init_without_params()
+        self.check_config_arguments_init()
 
 
 @is_staging_test
@@ -183,3 +256,15 @@ class ConfigTestUtils(unittest.TestCase):
         self.assertEqual(resid_pdrop, c.resid_pdrop, "mismatch for key: resid_pdrop")
         self.assertEqual(scale_attn_weights, c.scale_attn_weights, "mismatch for key: scale_attn_weights")
         self.assertEqual(summary_type, c.summary_type, "mismatch for key: summary_type")
+
+    def test_config_common_kwargs_is_complete(self):
+        base_config = PretrainedConfig()
+        missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
+        # If this part of the test fails, you have arguments to add in config_common_kwargs above.
+        self.assertListEqual(missing_keys, ["is_encoder_decoder", "_name_or_path", "transformers_version"])
+        keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
+        if len(keys_with_defaults) > 0:
+            raise ValueError(
+                "The following keys are set with the default values in `test_configuration_common.config_common_kwargs`, "
+                f"pick another value for them: {', '.join(keys_with_defaults)}."
+            )
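The completeness test keeps config_common_kwargs aligned with PretrainedConfig: every attribute the base class sets must appear in the dict with a non-default value, except the allow-listed keys (is_encoder_decoder, per the commit message, plus the _name_or_path and transformers_version bookkeeping fields). A hedged sketch of the invariant being asserted, reusing the config_common_kwargs dict defined above:

from transformers import PretrainedConfig

ignored = {"is_encoder_decoder", "_name_or_path", "transformers_version"}
base_attrs = set(PretrainedConfig().__dict__)
# Any new common attribute added to PretrainedConfig would show up here and force
# an update to config_common_kwargs, keeping the strong init test exhaustive.
uncovered = base_attrs - set(config_common_kwargs) - ignored  # expected to be empty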