mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add tests.
Maybe not the best possible place for the tests, lmk.
This commit is contained in:
parent
18e1f751f1
commit
4f15e5a267
@ -22,7 +22,7 @@ import logging
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .utils import require_torch, slow
|
||||
from .utils import require_torch, slow, SMALL_MODEL_IDENTIFIER
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import (AutoConfig, BertConfig,
|
||||
@ -92,6 +92,11 @@ class AutoModelTest(unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, BertForQuestionAnswering)
|
||||
|
||||
def test_from_pretrained_identifier(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
|
||||
self.assertIsInstance(model, BertForMaskedLM)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -22,7 +22,7 @@ import logging
|
||||
|
||||
from transformers import is_tf_available
|
||||
|
||||
from .utils import require_tf, slow
|
||||
from .utils import require_tf, slow, SMALL_MODEL_IDENTIFIER
|
||||
|
||||
if is_tf_available():
|
||||
from transformers import (AutoConfig, BertConfig,
|
||||
@ -93,6 +93,11 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, TFBertForQuestionAnswering)
|
||||
|
||||
def test_from_pretrained_identifier(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, force_download=True)
|
||||
self.assertIsInstance(model, TFBertForMaskedLM)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -23,7 +23,7 @@ import logging
|
||||
from transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer
|
||||
from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
from .utils import slow
|
||||
from .utils import slow, SMALL_MODEL_IDENTIFIER
|
||||
|
||||
|
||||
class AutoTokenizerTest(unittest.TestCase):
|
||||
@ -42,6 +42,11 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
self.assertIsInstance(tokenizer, GPT2Tokenizer)
|
||||
self.assertGreater(len(tokenizer), 0)
|
||||
|
||||
def test_tokenizer_from_pretrained_identifier(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
|
||||
self.assertIsInstance(tokenizer, BertTokenizer)
|
||||
self.assertEqual(len(tokenizer), 12)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -6,6 +6,9 @@ from distutils.util import strtobool
|
||||
from transformers.file_utils import _tf_available, _torch_available
|
||||
|
||||
|
||||
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
|
||||
|
||||
|
||||
try:
|
||||
run_slow = os.environ["RUN_SLOW"]
|
||||
except KeyError:
|
||||
|
Loading…
Reference in New Issue
Block a user