diff --git a/transformers/tests/modeling_common_test.py b/transformers/tests/modeling_common_test.py index c84162117a0..c03d307e715 100644 --- a/transformers/tests/modeling_common_test.py +++ b/transformers/tests/modeling_common_test.py @@ -218,12 +218,11 @@ class CommonTestCases: inputs = inputs_dict['input_ids'] # Let's keep only input_ids try: - torch.jit.trace(model, inputs) + traced_gpt2 = torch.jit.trace(model, inputs) except RuntimeError: self.fail("Couldn't trace module.") try: - traced_gpt2 = torch.jit.trace(model, inputs) torch.jit.save(traced_gpt2, "traced_model.pt") except RuntimeError: self.fail("Couldn't save module.")