diff --git a/transformers/tests/modeling_auto_test.py b/transformers/tests/modeling_auto_test.py
index 9b7d920bc86..871a262fe8c 100644
--- a/transformers/tests/modeling_auto_test.py
+++ b/transformers/tests/modeling_auto_test.py
@@ -22,7 +22,7 @@ import logging
 
 from transformers import is_torch_available
 
-from .utils import require_torch, slow
+from .utils import require_torch, slow, SMALL_MODEL_IDENTIFIER
 
 if is_torch_available():
     from transformers import (AutoConfig, BertConfig,
@@ -92,6 +92,11 @@ class AutoModelTest(unittest.TestCase):
             self.assertIsNotNone(model)
             self.assertIsInstance(model, BertForQuestionAnswering)
 
+    def test_from_pretrained_identifier(self):
+        logging.basicConfig(level=logging.INFO)
+        model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
+        self.assertIsInstance(model, BertForMaskedLM)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/transformers/tests/modeling_tf_auto_test.py b/transformers/tests/modeling_tf_auto_test.py
index 7ea48015d9b..7ab6eaa3d63 100644
--- a/transformers/tests/modeling_tf_auto_test.py
+++ b/transformers/tests/modeling_tf_auto_test.py
@@ -22,7 +22,7 @@ import logging
 
 from transformers import is_tf_available
 
-from .utils import require_tf, slow
+from .utils import require_tf, slow, SMALL_MODEL_IDENTIFIER
 
 if is_tf_available():
     from transformers import (AutoConfig, BertConfig,
@@ -93,6 +93,11 @@ class TFAutoModelTest(unittest.TestCase):
             self.assertIsNotNone(model)
             self.assertIsInstance(model, TFBertForQuestionAnswering)
 
+    def test_from_pretrained_identifier(self):
+        logging.basicConfig(level=logging.INFO)
+        model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, force_download=True)
+        self.assertIsInstance(model, TFBertForMaskedLM)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/transformers/tests/tokenization_auto_test.py b/transformers/tests/tokenization_auto_test.py
index 18346d27688..0a894cac043 100644
--- a/transformers/tests/tokenization_auto_test.py
+++ b/transformers/tests/tokenization_auto_test.py
@@ -23,7 +23,7 @@ import logging
 
 from transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer
 from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
-from .utils import slow
+from .utils import slow, SMALL_MODEL_IDENTIFIER
 
 
 class AutoTokenizerTest(unittest.TestCase):
@@ -42,6 +42,11 @@ class AutoTokenizerTest(unittest.TestCase):
             self.assertIsInstance(tokenizer, GPT2Tokenizer)
             self.assertGreater(len(tokenizer), 0)
 
+    def test_tokenizer_from_pretrained_identifier(self):
+        logging.basicConfig(level=logging.INFO)
+        tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
+        self.assertIsInstance(tokenizer, BertTokenizer)
+        self.assertEqual(len(tokenizer), 12)
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/transformers/tests/utils.py b/transformers/tests/utils.py
index 7a51ab612b6..3aff1daf835 100644
--- a/transformers/tests/utils.py
+++ b/transformers/tests/utils.py
@@ -6,6 +6,9 @@ from distutils.util import strtobool
 
 from transformers.file_utils import _tf_available, _torch_available
 
+SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
+
+
 try:
     run_slow = os.environ["RUN_SLOW"]
 except KeyError:
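
Note: a minimal sketch of what the new tests exercise, assuming a transformers
version contemporary with this patch and network access to fetch the checkpoint
(the "julien-c/bert-xsmall-dummy" model is resolved by its "user/model" hub
identifier rather than a built-in shortcut name):

    from transformers import AutoModelWithLMHead, AutoTokenizer

    # The Auto classes inspect the downloaded config to select the concrete
    # architecture; for this dummy checkpoint that is a tiny BERT.
    model = AutoModelWithLMHead.from_pretrained("julien-c/bert-xsmall-dummy")
    tokenizer = AutoTokenizer.from_pretrained("julien-c/bert-xsmall-dummy")

    print(type(model).__name__)  # BertForMaskedLM
    print(len(tokenizer))        # 12 -- the dummy vocabulary size asserted above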