config.architectures

This commit is contained in:
Julien Chaumond 2020-01-30 16:45:52 -05:00 committed by Lysandre Debut
parent f9bc3f5771
commit b85c59f997
2 changed files with 4 additions and 0 deletions

View File

@ -82,6 +82,7 @@ class PretrainedConfig(object):
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
# Fine-tuning task arguments
self.architectures = kwargs.pop("architectures", None)
self.finetuning_task = kwargs.pop("finetuning_task", None)
self.num_labels = kwargs.pop("num_labels", 2)
self.id2label = kwargs.pop("id2label", {i: "LABEL_{}".format(i) for i in range(self.num_labels)})

View File

@ -284,6 +284,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# Only save the model itself if we are using distributed training
model_to_save = self.module if hasattr(self, "module") else self
# Attach architecture to the config
model_to_save.config.architectures = [model_to_save.__class__.__name__]
# Save configuration file
model_to_save.config.save_pretrained(save_directory)