Remove redundant torch.jit.trace in tests.

This looks like it could be expensive, so don't run it twice.
This commit is contained in:
Aymeric Augustin 2019-12-20 20:56:58 +01:00
parent ac1b449cc9
commit 12726f8556

View File

@ -218,12 +218,11 @@ class CommonTestCases:
inputs = inputs_dict['input_ids'] # Let's keep only input_ids
try:
torch.jit.trace(model, inputs)
traced_gpt2 = torch.jit.trace(model, inputs)
except RuntimeError:
self.fail("Couldn't trace module.")
try:
traced_gpt2 = torch.jit.trace(model, inputs)
torch.jit.save(traced_gpt2, "traced_model.pt")
except RuntimeError:
self.fail("Couldn't save module.")