From 12726f8556152dbc6c115327646ebb33ccb2bc4f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 20 Dec 2019 20:56:58 +0100 Subject: [PATCH] Remove redundant torch.jit.trace in tests. This looks like it could be expensive, so don't run it twice. --- transformers/tests/modeling_common_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transformers/tests/modeling_common_test.py b/transformers/tests/modeling_common_test.py index c84162117a0..c03d307e715 100644 --- a/transformers/tests/modeling_common_test.py +++ b/transformers/tests/modeling_common_test.py @@ -218,12 +218,11 @@ class CommonTestCases: inputs = inputs_dict['input_ids'] # Let's keep only input_ids try: - torch.jit.trace(model, inputs) + traced_gpt2 = torch.jit.trace(model, inputs) except RuntimeError: self.fail("Couldn't trace module.") try: - traced_gpt2 = torch.jit.trace(model, inputs) torch.jit.save(traced_gpt2, "traced_model.pt") except RuntimeError: self.fail("Couldn't save module.")