diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py
index 9ca3a3d0905..bb2b82b41c8 100644
--- a/pytorch_transformers/modeling_utils.py
+++ b/pytorch_transformers/modeling_utils.py
@@ -48,6 +48,17 @@ class PretrainedConfig(object):
         self.output_hidden_states = kwargs.pop('output_hidden_states', False)
         self.torchscript = kwargs.pop('torchscript', False)
 
+    def save_pretrained(self, save_directory):
+        """ Save a configuration file to a directory, so that it
+            can be re-loaded using the `from_pretrained(save_directory)` class method.
+        """
+        assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
+
+        # If we save using the predefined names, we can load using `from_pretrained`
+        output_config_file = os.path.join(save_directory, CONFIG_NAME)
+
+        self.to_json_file(output_config_file)
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
         """
@@ -248,12 +259,13 @@ class PreTrainedModel(nn.Module):
         # Only save the model it-self if we are using distributed training
         model_to_save = self.module if hasattr(self, 'module') else self
 
+        # Save configuration file
+        model_to_save.config.save_pretrained(save_directory)
+
         # If we save using the predefined names, we can load using `from_pretrained`
         output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
-        output_config_file = os.path.join(save_directory, CONFIG_NAME)
 
         torch.save(model_to_save.state_dict(), output_model_file)
-        model_to_save.config.to_json_file(output_config_file)
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
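
For context, a minimal usage sketch of the save/load round trip these methods enable. The checkpoint name and output directory below are illustrative, not taken from the diff:

# Illustrative round trip: save a model (weights + config) and re-load it.
import os
from pytorch_transformers import BertModel, BertConfig

save_directory = './saved_model'  # hypothetical path
os.makedirs(save_directory, exist_ok=True)  # save_pretrained asserts the directory exists

model = BertModel.from_pretrained('bert-base-uncased')
# Writes pytorch_model.bin (WEIGHTS_NAME) and, via the new
# PretrainedConfig.save_pretrained, config.json (CONFIG_NAME).
model.save_pretrained(save_directory)

# Both the configuration alone and the full model can be re-loaded:
config = BertConfig.from_pretrained(save_directory)
reloaded_model = BertModel.from_pretrained(save_directory)

Factoring the config write out into PretrainedConfig.save_pretrained keeps the model-side save_pretrained focused on weights and lets a configuration be saved and re-loaded independently of any model instance.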