diff --git a/pytorch_pretrained_bert/tests/model_tests_commons.py b/pytorch_pretrained_bert/tests/model_tests_commons.py index 75c0ae19fde..0afda5f2ce9 100644 --- a/pytorch_pretrained_bert/tests/model_tests_commons.py +++ b/pytorch_pretrained_bert/tests/model_tests_commons.py @@ -198,14 +198,17 @@ def _create_and_check_for_hidden_states(tester, model_classes, config, inputs_di [tester.seq_length, tester.hidden_size]) -def create_and_check_commons(tester, config, inputs_dict, test_pruning=True): +def create_and_check_commons(tester, config, inputs_dict, test_pruning=True, test_torchscript=True): _create_and_check_initialization(tester, tester.all_model_classes, config, inputs_dict) - _create_and_check_torchscript(tester, tester.all_model_classes, config, inputs_dict) - _create_and_check_torchscript_output_attentions(tester, tester.all_model_classes, config, inputs_dict) - _create_and_check_torchscript_output_hidden_state(tester, tester.all_model_classes, config, inputs_dict) _create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict) _create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict) _create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict) + + if test_torchscript: + _create_and_check_torchscript(tester, tester.all_model_classes, config, inputs_dict) + _create_and_check_torchscript_output_attentions(tester, tester.all_model_classes, config, inputs_dict) + _create_and_check_torchscript_output_hidden_state(tester, tester.all_model_classes, config, inputs_dict) + if test_pruning: _create_and_check_for_head_pruning(tester, tester.all_model_classes, config, inputs_dict) diff --git a/pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py b/pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py index 8b46b6d7556..caeb25b4126 100644 --- a/pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py +++ b/pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py @@ -173,7 +173,7 @@ class TransfoXLModelTest(unittest.TestCase): def create_and_check_transfo_xl_commons(self, config, input_ids_1, input_ids_2, lm_labels): inputs_dict = {'input_ids': input_ids_1} - create_and_check_commons(self, config, inputs_dict, test_pruning=False) + create_and_check_commons(self, config, inputs_dict, test_pruning=False, test_torchscript=False) def test_default(self): self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))