Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)
Add strong test for configuration attributes (#14000)
* Add strong test for configuration attributes
* Add fake modif to trigger all tests
* Add a better fake modif
* Ignore is_encoder_decoder
* Fix faulty configs
* Remove fake modif
parent 0ef61d392c
commit f2002fea11
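At a high level, the new check instantiates each configuration class with every common keyword argument set to a non-default value and verifies that each one is actually reflected on the resulting object. A minimal sketch of the idea (BertConfig is just an arbitrary example here; the real test loops over the full config_common_kwargs dict added below):

from transformers import BertConfig

# Pass a couple of common kwargs with deliberately non-default values ...
config = BertConfig(num_beams=3, length_penalty=0.8)

# ... and check that the config actually kept them. A config whose __init__
# silently drops or overrides them (as FSMT and LXMERT did before this commit) fails here.
assert config.num_beams == 3
assert config.length_penalty == 0.8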
@@ -166,7 +166,6 @@ class FSMTConfig(PretrainedConfig):
        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        self.d_model = d_model  # encoder_embed_dim and decoder_embed_dim
        self.max_length = max_length

        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = self.num_hidden_layers = encoder_layers
@@ -180,10 +179,6 @@ class FSMTConfig(PretrainedConfig):
        self.init_std = init_std  # Normal(0, this parameter)
        self.activation_function = activation_function

        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping

        self.decoder = DecoderConfig(vocab_size=tgt_vocab_size, bos_token_id=eos_token_id)
        if "decoder" in common_kwargs:
            del common_kwargs["decoder"]
@@ -204,6 +199,10 @@ class FSMTConfig(PretrainedConfig):
            is_encoder_decoder=is_encoder_decoder,
            tie_word_embeddings=tie_word_embeddings,
            forced_eos_token_id=forced_eos_token_id,
            max_length=max_length,
            num_beams=num_beams,
            length_penalty=length_penalty,
            early_stopping=early_stopping,
            **common_kwargs,
        )

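The FSMT hunks above matter because PretrainedConfig.__init__ assigns the generation defaults itself: a subclass that sets them as plain attributes and only then calls super().__init__() without forwarding the values sees them silently reset to the base defaults. A rough sketch of that failure mode, assuming the base class still defaults max_length to 20 as it did at the time of this commit (BrokenConfig and FixedConfig are invented names, not the actual FSMT code):

from transformers import PretrainedConfig


class BrokenConfig(PretrainedConfig):
    def __init__(self, max_length=200, **kwargs):
        self.max_length = max_length  # set first ...
        super().__init__(**kwargs)  # ... then overwritten by the base default (20)


class FixedConfig(PretrainedConfig):
    def __init__(self, max_length=200, **kwargs):
        super().__init__(max_length=max_length, **kwargs)  # forwarded, so the value sticks


assert BrokenConfig().max_length == 20  # the intended 200 was lost
assert FixedConfig().max_length == 200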
@@ -131,7 +131,6 @@ class LxmertConfig(PretrainedConfig):
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        l_layers=9,
        x_layers=5,
        r_layers=5,
@@ -145,8 +144,6 @@ class LxmertConfig(PretrainedConfig):
        visual_obj_loss=True,
        visual_attr_loss=True,
        visual_feat_loss=True,
        output_attentions=False,
        output_hidden_states=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
@@ -176,7 +173,5 @@ class LxmertConfig(PretrainedConfig):
        self.visual_obj_loss = visual_obj_loss
        self.visual_attr_loss = visual_attr_loss
        self.visual_feat_loss = visual_feat_loss
        self.output_hidden_states = output_hidden_states
        self.output_attentions = output_attentions
        self.num_hidden_layers = {"vision": r_layers, "cross_encoder": x_layers, "language": l_layers}
        super().__init__(**kwargs)
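The LXMERT hunks fix the same class of bug: output_attentions and output_hidden_states were accepted as explicit arguments and assigned before super().__init__(**kwargs), so the base class immediately overrode whatever the caller passed. With the fix they simply flow through **kwargs. A quick sanity check of the behaviour the new test expects (a sketch, not taken from the test suite):

from transformers import LxmertConfig

# After the fix, a value passed at construction time is honoured;
# before it, the base-class default (False) silently won.
config = LxmertConfig(output_attentions=True)
assert config.output_attentions is True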
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import copy
import json
import os
import tempfile
@@ -21,10 +21,62 @@ import unittest

from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import BertConfig, GPT2Config
from transformers import BertConfig, GPT2Config, is_torch_available
from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test


config_common_kwargs = {
    "return_dict": False,
    "output_hidden_states": True,
    "output_attentions": True,
    "torchscript": True,
    "torch_dtype": "float16",
    "use_bfloat16": True,
    "pruned_heads": {"a": 1},
    "tie_word_embeddings": False,
    "is_decoder": True,
    "cross_attention_hidden_size": 128,
    "add_cross_attention": True,
    "tie_encoder_decoder": True,
    "max_length": 50,
    "min_length": 3,
    "do_sample": True,
    "early_stopping": True,
    "num_beams": 3,
    "num_beam_groups": 3,
    "diversity_penalty": 0.5,
    "temperature": 2.0,
    "top_k": 10,
    "top_p": 0.7,
    "repetition_penalty": 0.8,
    "length_penalty": 0.8,
    "no_repeat_ngram_size": 5,
    "encoder_no_repeat_ngram_size": 5,
    "bad_words_ids": [1, 2, 3],
    "num_return_sequences": 3,
    "chunk_size_feed_forward": 5,
    "output_scores": True,
    "return_dict_in_generate": True,
    "forced_bos_token_id": 2,
    "forced_eos_token_id": 3,
    "remove_invalid_values": True,
    "architectures": ["BertModel"],
    "finetuning_task": "translation",
    "id2label": {0: "label"},
    "label2id": {"label": "0"},
    "tokenizer_class": "BertTokenizerFast",
    "prefix": "prefix",
    "bos_token_id": 6,
    "pad_token_id": 7,
    "eos_token_id": 8,
    "sep_token_id": 9,
    "decoder_start_token_id": 10,
    "task_specific_params": {"translation": "some_params"},
    "problem_type": "regression",
}


class ConfigTester(object):
    def __init__(self, parent, config_class=None, has_text_modality=True, **kwargs):
        self.parent = parent
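One entry in config_common_kwargs deserves a note: torch_dtype is passed as the string "float16" but compared against torch.float16 in the check below, because PretrainedConfig resolves dtype strings to real torch dtypes when torch is installed (hence the is_torch_available() guard). A small sketch of that behaviour, assuming torch is installed:

from transformers import BertConfig, is_torch_available

config = BertConfig(torch_dtype="float16")
if is_torch_available():
    import torch

    # The string is resolved to the actual torch dtype on the config object.
    assert config.torch_dtype == torch.float16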
@@ -108,6 +160,26 @@ class ConfigTester(object):
        config = self.config_class()
        self.parent.assertIsNotNone(config)

    def check_config_arguments_init(self):
        kwargs = copy.deepcopy(config_common_kwargs)
        config = self.config_class(**kwargs)
        wrong_values = []
        for key, value in config_common_kwargs.items():
            if key == "torch_dtype":
                if not is_torch_available():
                    continue
                else:
                    import torch

                    if config.torch_dtype != torch.float16:
                        wrong_values.append(("torch_dtype", config.torch_dtype, torch.float16))
            elif getattr(config, key) != value:
                wrong_values.append((key, getattr(config, key), value))

        if len(wrong_values) > 0:
            errors = "\n".join([f"- {v[0]}: got {v[1]} instead of {v[2]}" for v in wrong_values])
            raise ValueError(f"The following keys were not properly set in the config:\n{errors}")

    def run_common_tests(self):
        self.create_and_test_config_common_properties()
        self.create_and_test_config_to_json_string()
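For reference, the shared tester is wired into each model's test suite roughly like this, so check_config_arguments_init() now runs for every model that calls run_common_tests(). A schematic sketch (BertConfigTest is an illustrative name, not an exact copy of the real Bert test):

import unittest

from transformers import BertConfig

from .test_configuration_common import ConfigTester


class BertConfigTest(unittest.TestCase):
    def test_config(self):
        # run_common_tests() now also invokes check_config_arguments_init(),
        # so every key in config_common_kwargs is exercised for BertConfig.
        config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)
        config_tester.run_common_tests()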
@@ -115,6 +187,7 @@ class ConfigTester(object):
        self.create_and_test_config_from_and_save_pretrained()
        self.create_and_test_config_with_num_labels()
        self.check_config_can_be_init_without_params()
        self.check_config_arguments_init()


@is_staging_test
@@ -183,3 +256,15 @@ class ConfigTestUtils(unittest.TestCase):
        self.assertEqual(resid_pdrop, c.resid_pdrop, "mismatch for key: resid_pdrop")
        self.assertEqual(scale_attn_weights, c.scale_attn_weights, "mismatch for key: scale_attn_weights")
        self.assertEqual(summary_type, c.summary_type, "mismatch for key: summary_type")

    def test_config_common_kwargs_is_complete(self):
        base_config = PretrainedConfig()
        missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
        # If this part of the test fails, you have arguments to add in config_common_kwargs above.
        self.assertListEqual(missing_keys, ["is_encoder_decoder", "_name_or_path", "transformers_version"])
        keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
        if len(keys_with_defaults) > 0:
            raise ValueError(
                "The following keys are set with the default values in `test_configuration_common.config_common_kwargs`, "
                f"pick another value for them: {', '.join(keys_with_defaults)}."
            )
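If a new common attribute is later added to PretrainedConfig, this completeness test fails until config_common_kwargs gains a matching non-default entry. A hypothetical example (new_generation_flag is an invented attribute name, purely for illustration):

# Suppose PretrainedConfig.__init__ gained
#     self.new_generation_flag = kwargs.pop("new_generation_flag", False)
# test_config_common_kwargs_is_complete would then list "new_generation_flag" as missing,
# and the fix is to extend the dict with a value that differs from that default:
config_common_kwargs = {
    # ... existing entries ...
    "new_generation_flag": True,  # hypothetical attribute, non-default value
}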