mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
- Create the output directory (whose name is passed by the user in the "save_directory" parameter) where it will be saved encoder and decoder, if not exists.
- Empty the output directory, if it contains any files or subdirectories. - Create the "encoder" directory inside "save_directory", if not exists. - Create the "decoder" directory inside "save_directory", if not exists. - Save the encoder and the decoder in the previous two directories, respectively.
This commit is contained in:
parent
a436574bfd
commit
3df1d2d144
@ -166,6 +166,37 @@ class PreTrainedEncoderDecoder(nn.Module):
|
|||||||
|
|
||||||
We save the encoder' and decoder's parameters in two separate directories.
|
We save the encoder' and decoder's parameters in two separate directories.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# If the root output directory does not exist, create it
|
||||||
|
if not os.path.exists(save_directory):
|
||||||
|
os.mkdir(save_directory)
|
||||||
|
|
||||||
|
# Check whether the output directory is empty or not
|
||||||
|
sub_directories = [directory for directory in os.listdir(save_directory)
|
||||||
|
if os.path.isdir(os.path.join(save_directory, directory))]
|
||||||
|
|
||||||
|
if len(sub_directories) > 0:
|
||||||
|
if "encoder" in sub_directories and "decoder" in sub_directories:
|
||||||
|
print("WARNING: there is an older version of encoder-decoder saved in" +\
|
||||||
|
" the output directory. The default behaviour is to overwrite them.")
|
||||||
|
|
||||||
|
# Empty the output directory
|
||||||
|
for directory_to_remove in sub_directories:
|
||||||
|
# Remove all files into the subdirectory
|
||||||
|
files_to_remove = os.listdir(os.path.join(save_directory, directory_to_remove))
|
||||||
|
for file_to_remove in files_to_remove:
|
||||||
|
os.remove(os.path.join(save_directory, directory_to_remove, file_to_remove))
|
||||||
|
# Remove the subdirectory itself
|
||||||
|
os.rmdir(os.path.join(save_directory, directory_to_remove))
|
||||||
|
|
||||||
|
assert(len(os.listdir(save_directory)) == 0) # sanity check
|
||||||
|
|
||||||
|
if not os.path.exists(os.path.join(save_directory, "encoder")):
|
||||||
|
os.mkdir(os.path.join(save_directory, "encoder"))
|
||||||
|
|
||||||
|
if not os.path.exists(os.path.join(save_directory, "decoder")):
|
||||||
|
os.mkdir(os.path.join(save_directory, "decoder"))
|
||||||
|
|
||||||
self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
|
self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
|
||||||
self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
|
self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user