Mirror of https://github.com/huggingface/transformers.git
fix tests - flag as slow all the tests that download from AWS
This commit is contained in:
parent f02805da6f
commit b340a910ed
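The mechanism behind this change is pytest's marker system: every test that downloads pretrained weights from the AWS/S3 bucket is decorated with @pytest.mark.slow so a fast run can deselect it. Below is a minimal sketch of how such a marker is typically registered and filtered; the conftest.py shown here is an illustrative assumption, not a file touched by this commit.

# conftest.py - illustrative sketch only, not part of this commit
def pytest_configure(config):
    # Register the custom marker so pytest does not warn about an unknown mark.
    config.addinivalue_line(
        "markers", "slow: test downloads pretrained weights from AWS/S3")

A quick local or CI run can then skip the network-bound tests with pytest -m "not slow", while pytest -m slow runs only the flagged tests and a plain pytest invocation still runs everything.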
@@ -173,7 +173,7 @@ class RobertaModel(BertModel):
         return self.embeddings.word_embeddings
 
     def set_input_embeddings(self, value):
-        self.embeddings.word_emebddings = value
+        self.embeddings.word_embeddings = value
 
 
 @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """,
                       ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
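The one-character fix above matters because set_input_embeddings is the counterpart of get_input_embeddings: with the misspelled attribute, a new embedding module was stored under word_emebddings and never replaced the table that get_input_embeddings returns. A minimal sketch of the accessor pair follows; the "roberta-base" checkpoint and the copy-the-weights step are illustrative assumptions, not code from this commit.

# Illustrative sketch only, not code from this commit.
import torch
from transformers import RobertaModel

model = RobertaModel.from_pretrained("roberta-base")      # assumed checkpoint

old_embeddings = model.get_input_embeddings()             # the word-embedding module
new_embeddings = torch.nn.Embedding(old_embeddings.num_embeddings,
                                    old_embeddings.embedding_dim)
new_embeddings.weight.data.copy_(old_embeddings.weight.data)

model.set_input_embeddings(new_embeddings)
# Before the fix this assertion would fail: the model kept using the old table
# while a stray word_emebddings attribute held the new one.
assert model.get_input_embeddings() is new_embeddings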
@@ -38,6 +38,7 @@ else:
 
 
 class AutoModelTest(unittest.TestCase):
+    @pytest.mark.slow
     def test_model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
         for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
@@ -52,6 +53,7 @@ class AutoModelTest(unittest.TestCase):
         for value in loading_info.values():
             self.assertEqual(len(value), 0)
 
+    @pytest.mark.slow
     def test_lmhead_model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
         for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
@@ -64,6 +66,7 @@ class AutoModelTest(unittest.TestCase):
         self.assertIsNotNone(model)
         self.assertIsInstance(model, BertForMaskedLM)
 
+    @pytest.mark.slow
     def test_sequence_classification_model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
         for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
@@ -76,6 +79,7 @@ class AutoModelTest(unittest.TestCase):
         self.assertIsNotNone(model)
         self.assertIsInstance(model, BertForSequenceClassification)
 
+    @pytest.mark.slow
     def test_question_answering_model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
         for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
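For context, the assertions being patched above unpack the optional loading_info dict that from_pretrained returns when output_loading_info=True, and the download it triggers is exactly why these tests are now marked slow. A minimal sketch of that call pattern, assuming network access and using "bert-base-uncased" purely for illustration:

# Illustrative sketch; downloads weights on first use, hence the slow marker.
import logging
from transformers import AutoModel

logging.basicConfig(level=logging.INFO)

model, loading_info = AutoModel.from_pretrained(
    "bert-base-uncased",          # assumed checkpoint name
    output_loading_info=True)

# The tests assert every list in loading_info (missing keys, unexpected keys,
# error messages) is empty, i.e. the checkpoint matched the architecture.
for value in loading_info.values():
    assert len(value) == 0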
@@ -429,12 +429,6 @@ class CommonTestCases:
                 list(hidden_states[0].shape[-2:]),
                 [self.model_tester.seq_length, self.model_tester.hidden_size])
 
-    def test_debug(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        for model_class in self.all_model_classes:
-            model = model_class(config)
-            model_embed = model.resize_token_embeddings(config.vocab_size + 10)
-
     def test_resize_tokens_embeddings(self):
         original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         if not self.test_resize_embeddings:
@@ -703,6 +697,7 @@ class CommonTestCases:
         config_and_inputs = self.prepare_config_and_inputs()
         self.create_and_check_presents(*config_and_inputs)
 
+    @pytest.mark.slow
     def run_slow_tests(self):
         self.create_and_check_model_from_pretrained()
 
@@ -776,6 +771,7 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
 
 
 class ModelUtilsTest(unittest.TestCase):
+    @pytest.mark.slow
     def test_model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
         for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
@@ -27,6 +27,7 @@ else:
 
 
 class EncoderDecoderModelTest(unittest.TestCase):
+    @pytest.mark.slow
     def test_model2model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
         for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
@@ -26,6 +26,7 @@ from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CON
 
 
 class AutoTokenizerTest(unittest.TestCase):
+    @pytest.mark.slow
     def test_tokenizer_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
         for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
@@ -16,6 +16,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 
 import os
 import unittest
+import pytest
 from io import open
 
 from transformers.tokenization_bert import (BasicTokenizer,
@@ -125,6 +126,7 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
         self.assertFalse(_is_punctuation(u"A"))
         self.assertFalse(_is_punctuation(u" "))
 
+    @pytest.mark.slow
     def test_sequence_builders(self):
         tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased")
 
@@ -16,6 +16,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 
 import os
 import unittest
+import pytest
 from io import open
 
 from transformers.tokenization_distilbert import (DistilBertTokenizer)
@@ -30,6 +31,7 @@ class DistilBertTokenizationTest(BertTokenizationTest):
     def get_tokenizer(self, **kwargs):
         return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
 
+    @pytest.mark.slow
     def test_sequence_builders(self):
         tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
 
@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import os
 import json
 import unittest
+import pytest
 from io import open
 
 from transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
@@ -78,6 +79,7 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
             [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
         )
 
+    @pytest.mark.slow
     def test_sequence_builders(self):
         tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
 
@@ -18,11 +18,13 @@ from __future__ import print_function
 
 import unittest
 import six
+import pytest
 
 from transformers import PreTrainedTokenizer
 from transformers.tokenization_gpt2 import GPT2Tokenizer
 
 class TokenizerUtilsTest(unittest.TestCase):
+    @pytest.mark.slow
     def check_tokenizer_from_pretrained(self, tokenizer_class):
         s3_models = list(tokenizer_class.max_model_input_sizes.keys())
         for model_name in s3_models[:1]:
@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import os
 import unittest
 import json
+import pytest
 
 from transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
 
@@ -66,6 +67,7 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
 
+    @pytest.mark.slow
     def test_sequence_builders(self):
         tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
 
@@ -16,6 +16,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 
 import os
 import unittest
+import pytest
 
 from transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
 
@@ -89,6 +90,7 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
             u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
             SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
 
+    @pytest.mark.slow
     def test_sequence_builders(self):
         tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
 