mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
- Create the output directory (whose name is passed by the user in the "save_directory" parameter) where it will be saved encoder and decoder, if not exists.
- Empty the output directory, if it contains any files or subdirectories. - Create the "encoder" directory inside "save_directory", if not exists. - Create the "decoder" directory inside "save_directory", if not exists. - Save the encoder and the decoder in the previous two directories, respectively.
This commit is contained in:
parent
a436574bfd
commit
3df1d2d144
@ -166,6 +166,37 @@ class PreTrainedEncoderDecoder(nn.Module):
|
||||
|
||||
We save the encoder' and decoder's parameters in two separate directories.
|
||||
"""
|
||||
|
||||
# If the root output directory does not exist, create it
|
||||
if not os.path.exists(save_directory):
|
||||
os.mkdir(save_directory)
|
||||
|
||||
# Check whether the output directory is empty or not
|
||||
sub_directories = [directory for directory in os.listdir(save_directory)
|
||||
if os.path.isdir(os.path.join(save_directory, directory))]
|
||||
|
||||
if len(sub_directories) > 0:
|
||||
if "encoder" in sub_directories and "decoder" in sub_directories:
|
||||
print("WARNING: there is an older version of encoder-decoder saved in" +\
|
||||
" the output directory. The default behaviour is to overwrite them.")
|
||||
|
||||
# Empty the output directory
|
||||
for directory_to_remove in sub_directories:
|
||||
# Remove all files into the subdirectory
|
||||
files_to_remove = os.listdir(os.path.join(save_directory, directory_to_remove))
|
||||
for file_to_remove in files_to_remove:
|
||||
os.remove(os.path.join(save_directory, directory_to_remove, file_to_remove))
|
||||
# Remove the subdirectory itself
|
||||
os.rmdir(os.path.join(save_directory, directory_to_remove))
|
||||
|
||||
assert(len(os.listdir(save_directory)) == 0) # sanity check
|
||||
|
||||
if not os.path.exists(os.path.join(save_directory, "encoder")):
|
||||
os.mkdir(os.path.join(save_directory, "encoder"))
|
||||
|
||||
if not os.path.exists(os.path.join(save_directory, "decoder")):
|
||||
os.mkdir(os.path.join(save_directory, "decoder"))
|
||||
|
||||
self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
|
||||
self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user