mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
update tests
This commit is contained in:
parent
56e98ba81a
commit
d3418a94ff
@ -16,15 +16,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import os
|
||||
import shutil
|
||||
import json
|
||||
import random
|
||||
import uuid
|
||||
import tempfile
|
||||
|
||||
import unittest
|
||||
import logging
|
||||
from .tokenization_tests_commons import TemporaryDirectory
|
||||
|
||||
|
||||
class ConfigTester(object):
|
||||
@ -48,16 +45,28 @@ class ConfigTester(object):
|
||||
|
||||
def create_and_test_config_to_json_file(self):
|
||||
config_first = self.config_class(**self.inputs_dict)
|
||||
json_file_path = os.path.join(os.getcwd(), "config_" + str(uuid.uuid4()) + ".json")
|
||||
config_first.to_json_file(json_file_path)
|
||||
config_second = self.config_class.from_json_file(json_file_path)
|
||||
os.remove(json_file_path)
|
||||
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
json_file_path = os.path.join(tmpdirname, "config.json")
|
||||
config_first.to_json_file(json_file_path)
|
||||
config_second = self.config_class.from_json_file(json_file_path)
|
||||
|
||||
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
|
||||
|
||||
def create_and_test_config_from_and_save_pretrained(self):
|
||||
config_first = self.config_class(**self.inputs_dict)
|
||||
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
config_first.save_pretrained(tmpdirname)
|
||||
config_second = self.config_class.from_pretrained(tmpdirname)
|
||||
|
||||
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
|
||||
|
||||
def run_common_tests(self):
|
||||
self.create_and_test_config_common_properties()
|
||||
self.create_and_test_config_to_json_string()
|
||||
self.create_and_test_config_to_json_file()
|
||||
self.create_and_test_config_from_and_save_pretrained()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -15,10 +15,7 @@
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import tempfile
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
from transformers.model_card import ModelCard
|
||||
@ -50,10 +47,6 @@ class ModelCardTester(unittest.TestCase):
|
||||
'ROUGE-1': 76,
|
||||
},
|
||||
}
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def test_model_card_common_properties(self):
|
||||
model_card = ModelCard.from_dict(self.inputs_dict)
|
||||
@ -83,5 +76,14 @@ class ModelCardTester(unittest.TestCase):
|
||||
|
||||
self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict())
|
||||
|
||||
def test_model_card_from_and_save_pretrained(self):
|
||||
model_card_first = ModelCard.from_dict(self.inputs_dict)
|
||||
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
model_card_first.save_pretrained(tmpdirname)
|
||||
model_card_second = ModelCard.from_pretrained(tmpdirname)
|
||||
|
||||
self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict())
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user