- Create the output directory (whose name is passed by the user in the "save_directory" parameter) where it will be saved encoder and decoder, if not exists.

- Empty the output directory, if it contains any files or subdirectories.
- Create the "encoder" directory inside "save_directory", if not exists.
- Create the "decoder" directory inside "save_directory", if not exists.
- Save the encoder and the decoder in the previous two directories, respectively.
This commit is contained in:
Francesco 2019-12-17 10:19:54 +01:00 committed by Lysandre Debut
parent a436574bfd
commit 3df1d2d144

View File

@ -166,6 +166,37 @@ class PreTrainedEncoderDecoder(nn.Module):
We save the encoder' and decoder's parameters in two separate directories.
"""
# If the root output directory does not exist, create it
if not os.path.exists(save_directory):
os.mkdir(save_directory)
# Check whether the output directory is empty or not
sub_directories = [directory for directory in os.listdir(save_directory)
if os.path.isdir(os.path.join(save_directory, directory))]
if len(sub_directories) > 0:
if "encoder" in sub_directories and "decoder" in sub_directories:
print("WARNING: there is an older version of encoder-decoder saved in" +\
" the output directory. The default behaviour is to overwrite them.")
# Empty the output directory
for directory_to_remove in sub_directories:
# Remove all files into the subdirectory
files_to_remove = os.listdir(os.path.join(save_directory, directory_to_remove))
for file_to_remove in files_to_remove:
os.remove(os.path.join(save_directory, directory_to_remove, file_to_remove))
# Remove the subdirectory itself
os.rmdir(os.path.join(save_directory, directory_to_remove))
assert(len(os.listdir(save_directory)) == 0) # sanity check
if not os.path.exists(os.path.join(save_directory, "encoder")):
os.mkdir(os.path.join(save_directory, "encoder"))
if not os.path.exists(os.path.join(save_directory, "decoder")):
os.mkdir(os.path.join(save_directory, "decoder"))
self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))