Use a random temp dir for writing file in tests.

This commit is contained in:
Aymeric Augustin 2019-12-20 20:56:58 +01:00
parent 12726f8556
commit 478e456e83

View File

@ -18,7 +18,7 @@ from __future__ import print_function
import copy
import sys
import os
import os.path
import shutil
import tempfile
import json
@ -222,16 +222,18 @@ class CommonTestCases:
except RuntimeError:
self.fail("Couldn't trace module.")
try:
torch.jit.save(traced_gpt2, "traced_model.pt")
except RuntimeError:
self.fail("Couldn't save module.")
with TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
try:
loaded_model = torch.jit.load("traced_model.pt")
os.remove("traced_model.pt")
except ValueError:
self.fail("Couldn't load module.")
try:
torch.jit.save(traced_gpt2, pt_file_name)
except Exception:
self.fail("Couldn't save module.")
try:
loaded_model = torch.jit.load(pt_file_name)
except Exception:
self.fail("Couldn't load module.")
model.to(torch_device)
model.eval()