mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Use a random temp dir for writing file in tests.
This commit is contained in:
parent
12726f8556
commit
478e456e83
@ -18,7 +18,7 @@ from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import sys
|
||||
import os
|
||||
import os.path
|
||||
import shutil
|
||||
import tempfile
|
||||
import json
|
||||
@ -222,16 +222,18 @@ class CommonTestCases:
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't trace module.")
|
||||
|
||||
try:
|
||||
torch.jit.save(traced_gpt2, "traced_model.pt")
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't save module.")
|
||||
with TemporaryDirectory() as tmp_dir_name:
|
||||
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
|
||||
|
||||
try:
|
||||
loaded_model = torch.jit.load("traced_model.pt")
|
||||
os.remove("traced_model.pt")
|
||||
except ValueError:
|
||||
self.fail("Couldn't load module.")
|
||||
try:
|
||||
torch.jit.save(traced_gpt2, pt_file_name)
|
||||
except Exception:
|
||||
self.fail("Couldn't save module.")
|
||||
|
||||
try:
|
||||
loaded_model = torch.jit.load(pt_file_name)
|
||||
except Exception:
|
||||
self.fail("Couldn't load module.")
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
Loading…
Reference in New Issue
Block a user