Add strong test for configuration attributes (#14000)

* Add strong test for configuration attributes

* Add fake modif to trigger all tests

* Add a better fake modif

* Ignore is_encoder_decoder

* Fix faulty configs

* Remove fake modif
This commit is contained in:
Sylvain Gugger 2021-10-14 09:07:08 -04:00 committed by GitHub
parent 0ef61d392c
commit f2002fea11
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 91 additions and 12 deletions

View File

@ -166,7 +166,6 @@ class FSMTConfig(PretrainedConfig):
self.src_vocab_size = src_vocab_size
self.tgt_vocab_size = tgt_vocab_size
self.d_model = d_model # encoder_embed_dim and decoder_embed_dim
self.max_length = max_length
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = self.num_hidden_layers = encoder_layers
@ -180,10 +179,6 @@ class FSMTConfig(PretrainedConfig):
self.init_std = init_std # Normal(0, this parameter)
self.activation_function = activation_function
self.num_beams = num_beams
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.decoder = DecoderConfig(vocab_size=tgt_vocab_size, bos_token_id=eos_token_id)
if "decoder" in common_kwargs:
del common_kwargs["decoder"]
@ -204,6 +199,10 @@ class FSMTConfig(PretrainedConfig):
is_encoder_decoder=is_encoder_decoder,
tie_word_embeddings=tie_word_embeddings,
forced_eos_token_id=forced_eos_token_id,
max_length=max_length,
num_beams=num_beams,
length_penalty=length_penalty,
early_stopping=early_stopping,
**common_kwargs,
)

View File

@ -131,7 +131,6 @@ class LxmertConfig(PretrainedConfig):
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
l_layers=9,
x_layers=5,
r_layers=5,
@ -145,8 +144,6 @@ class LxmertConfig(PretrainedConfig):
visual_obj_loss=True,
visual_attr_loss=True,
visual_feat_loss=True,
output_attentions=False,
output_hidden_states=False,
**kwargs,
):
self.vocab_size = vocab_size
@ -176,7 +173,5 @@ class LxmertConfig(PretrainedConfig):
self.visual_obj_loss = visual_obj_loss
self.visual_attr_loss = visual_attr_loss
self.visual_feat_loss = visual_feat_loss
self.output_hidden_states = output_hidden_states
self.output_attentions = output_attentions
self.num_hidden_layers = {"vision": r_layers, "cross_encoder": x_layers, "language": l_layers}
super().__init__(**kwargs)

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os
import tempfile
@ -21,10 +21,62 @@ import unittest
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import BertConfig, GPT2Config
from transformers import BertConfig, GPT2Config, is_torch_available
from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test
config_common_kwargs = {
"return_dict": False,
"output_hidden_states": True,
"output_attentions": True,
"torchscript": True,
"torch_dtype": "float16",
"use_bfloat16": True,
"pruned_heads": {"a": 1},
"tie_word_embeddings": False,
"is_decoder": True,
"cross_attention_hidden_size": 128,
"add_cross_attention": True,
"tie_encoder_decoder": True,
"max_length": 50,
"min_length": 3,
"do_sample": True,
"early_stopping": True,
"num_beams": 3,
"num_beam_groups": 3,
"diversity_penalty": 0.5,
"temperature": 2.0,
"top_k": 10,
"top_p": 0.7,
"repetition_penalty": 0.8,
"length_penalty": 0.8,
"no_repeat_ngram_size": 5,
"encoder_no_repeat_ngram_size": 5,
"bad_words_ids": [1, 2, 3],
"num_return_sequences": 3,
"chunk_size_feed_forward": 5,
"output_scores": True,
"return_dict_in_generate": True,
"forced_bos_token_id": 2,
"forced_eos_token_id": 3,
"remove_invalid_values": True,
"architectures": ["BertModel"],
"finetuning_task": "translation",
"id2label": {0: "label"},
"label2id": {"label": "0"},
"tokenizer_class": "BertTokenizerFast",
"prefix": "prefix",
"bos_token_id": 6,
"pad_token_id": 7,
"eos_token_id": 8,
"sep_token_id": 9,
"decoder_start_token_id": 10,
"task_specific_params": {"translation": "some_params"},
"problem_type": "regression",
}
class ConfigTester(object):
def __init__(self, parent, config_class=None, has_text_modality=True, **kwargs):
self.parent = parent
@ -108,6 +160,26 @@ class ConfigTester(object):
config = self.config_class()
self.parent.assertIsNotNone(config)
def check_config_arguments_init(self):
kwargs = copy.deepcopy(config_common_kwargs)
config = self.config_class(**kwargs)
wrong_values = []
for key, value in config_common_kwargs.items():
if key == "torch_dtype":
if not is_torch_available():
continue
else:
import torch
if config.torch_dtype != torch.float16:
wrong_values.append(("torch_dtype", config.torch_dtype, torch.float16))
elif getattr(config, key) != value:
wrong_values.append((key, getattr(config, key), value))
if len(wrong_values) > 0:
errors = "\n".join([f"- {v[0]}: got {v[1]} instead of {v[2]}" for v in wrong_values])
raise ValueError(f"The following keys were not properly sey in the config:\n{errors}")
def run_common_tests(self):
self.create_and_test_config_common_properties()
self.create_and_test_config_to_json_string()
@ -115,6 +187,7 @@ class ConfigTester(object):
self.create_and_test_config_from_and_save_pretrained()
self.create_and_test_config_with_num_labels()
self.check_config_can_be_init_without_params()
self.check_config_arguments_init()
@is_staging_test
@ -183,3 +256,15 @@ class ConfigTestUtils(unittest.TestCase):
self.assertEqual(resid_pdrop, c.resid_pdrop, "mismatch for key: resid_pdrop")
self.assertEqual(scale_attn_weights, c.scale_attn_weights, "mismatch for key: scale_attn_weights")
self.assertEqual(summary_type, c.summary_type, "mismatch for key: summary_type")
def test_config_common_kwargs_is_complete(self):
base_config = PretrainedConfig()
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
# If this part of the test fails, you have arguments to addin config_common_kwargs above.
self.assertListEqual(missing_keys, ["is_encoder_decoder", "_name_or_path", "transformers_version"])
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
if len(keys_with_defaults) > 0:
raise ValueError(
"The following keys are set with the default values in `test_configuration_common.config_common_kwargs` "
f"pick another value for them: {', '.join(keys_with_defaults)}."
)