mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
[tests] Flag to test on cuda
This commit is contained in:
parent
13d9135fa5
commit
27e015bd54
@ -7,6 +7,9 @@ def pytest_addoption(parser):
|
|||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--runslow", action="store_true", default=False, help="run slow tests"
|
"--runslow", action="store_true", default=False, help="run slow tests"
|
||||||
)
|
)
|
||||||
|
parser.addoption(
|
||||||
|
"--use_cuda", action="store_true", default=False, help="run tests on gpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def pytest_configure(config):
|
def pytest_configure(config):
|
||||||
@ -21,3 +24,8 @@ def pytest_collection_modifyitems(config, items):
|
|||||||
for item in items:
|
for item in items:
|
||||||
if "slow" in item.keywords:
|
if "slow" in item.keywords:
|
||||||
item.add_marker(skip_slow)
|
item.add_marker(skip_slow)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def use_cuda(request):
|
||||||
|
""" Run test on gpu """
|
||||||
|
return request.config.getoption("--use_cuda")
|
||||||
|
@ -35,6 +35,7 @@ else:
|
|||||||
pytestmark = pytest.mark.skip("Require Torch")
|
pytestmark = pytest.mark.skip("Require Torch")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("use_cuda")
|
||||||
class BertModelTest(CommonTestCases.CommonModelTester):
|
class BertModelTest(CommonTestCases.CommonModelTester):
|
||||||
|
|
||||||
all_model_classes = (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
|
all_model_classes = (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
|
||||||
@ -66,6 +67,7 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
|||||||
num_labels=3,
|
num_labels=3,
|
||||||
num_choices=4,
|
num_choices=4,
|
||||||
scope=None,
|
scope=None,
|
||||||
|
device='cpu',
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@ -89,25 +91,26 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
|||||||
self.num_labels = num_labels
|
self.num_labels = num_labels
|
||||||
self.num_choices = num_choices
|
self.num_choices = num_choices
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).to(self.device)
|
||||||
|
|
||||||
input_mask = None
|
input_mask = None
|
||||||
if self.use_input_mask:
|
if self.use_input_mask:
|
||||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2).to(self.device)
|
||||||
|
|
||||||
token_type_ids = None
|
token_type_ids = None
|
||||||
if self.use_token_type_ids:
|
if self.use_token_type_ids:
|
||||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size).to(self.device)
|
||||||
|
|
||||||
sequence_labels = None
|
sequence_labels = None
|
||||||
token_labels = None
|
token_labels = None
|
||||||
choice_labels = None
|
choice_labels = None
|
||||||
if self.use_labels:
|
if self.use_labels:
|
||||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size).to(self.device)
|
||||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels).to(self.device)
|
||||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
choice_labels = ids_tensor([self.batch_size], self.num_choices).to(self.device)
|
||||||
|
|
||||||
config = BertConfig(
|
config = BertConfig(
|
||||||
vocab_size_or_config_json_file=self.vocab_size,
|
vocab_size_or_config_json_file=self.vocab_size,
|
||||||
@ -141,6 +144,7 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
|||||||
|
|
||||||
def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
model = BertModel(config=config)
|
model = BertModel(config=config)
|
||||||
|
model.to(input_ids.device)
|
||||||
model.eval()
|
model.eval()
|
||||||
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||||
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
|
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
|
||||||
@ -309,7 +313,10 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
|||||||
def test_config(self):
|
def test_config(self):
|
||||||
self.config_tester.run_common_tests()
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
def test_bert_model(self):
|
def test_bert_model(self, use_cuda=False):
|
||||||
|
# ^^ This could be a real fixture
|
||||||
|
if use_cuda:
|
||||||
|
self.model_tester.device = "cuda"
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_bert_model(*config_and_inputs)
|
self.model_tester.create_and_check_bert_model(*config_and_inputs)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user