diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 62fed9ef045..510ac05ec68 100644
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -40,12 +40,17 @@ class PretrainedConfig(object):
             - ``pretrained_config_archive_map``: a python ``dict`` with `shortcut names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
             - ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`.
 
-        Parameters:
-            ``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
-            ``num_labels``: integer, default `2`. Number of classes to use when the model is a classification model (sequences/tokens)
-            ``output_attentions``: boolean, default `False`. Should the model returns attentions weights.
-            ``output_hidden_states``: string, default `False`. Should the model returns all hidden-states.
-            ``torchscript``: string, default `False`. Is the model used with Torchscript.
+        Args:
+            finetuning_task (:obj:`string` or :obj:`None`, `optional`, defaults to :obj:`None`):
+                Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
+            num_labels (:obj:`int`, `optional`, defaults to :obj:`2`):
+                Number of classes to use when the model is a classification model (sequences/tokens).
+            output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether the model should return attention weights.
+            output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether the model should return all hidden states.
+            torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether the model is to be used with TorchScript (for PyTorch models).
     """
     pretrained_config_archive_map = {}  # type: Dict[str, str]
     model_type = ""  # type: str
@@ -93,8 +98,13 @@ class PretrainedConfig(object):
             raise err
 
     def save_pretrained(self, save_directory):
-        """ Save a configuration object to the directory `save_directory`, so that it
-            can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
+        """
+        Save a configuration object to the directory `save_directory`, so that it
+        can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
+
+        Args:
+            save_directory (:obj:`string`):
+                Directory where the configuration JSON file will be saved.
         """
         assert os.path.isdir(
             save_directory
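A quick usage sketch for the attributes and ``save_pretrained`` behaviour documented in the two hunks above (not part of the patch; it assumes a working ``transformers`` install and uses ``BertConfig`` as an illustrative concrete ``PretrainedConfig`` subclass)::

    import os

    from transformers import BertConfig  # any concrete PretrainedConfig subclass behaves the same way

    # Constructor kwargs map onto the parameters documented in the class docstring.
    config = BertConfig(num_labels=3, output_attentions=True, output_hidden_states=False)

    # save_pretrained() asserts the directory exists, then writes a config.json inside it.
    save_dir = "./my_model_directory"  # hypothetical path
    os.makedirs(save_dir, exist_ok=True)
    config.save_pretrained(save_dir)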
@@ -107,40 +117,45 @@ class PretrainedConfig(object):
         logger.info("Configuration saved in {}".format(output_config_file))
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
-        r""" Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> 'PretrainedConfig':
+        r"""
 
-        Parameters:
-            pretrained_model_name_or_path: either:
+        Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.
 
-                - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``.
-                - a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
-                - a path to a `directory` containing a configuration file saved using the :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
-                - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``.
-
-            cache_dir: (`optional`) string:
+        Args:
+            pretrained_model_name_or_path (:obj:`string`):
+                either:
+                  - a string with the `shortcut name` of a pre-trained model configuration to load from cache or
+                    download, e.g.: ``bert-base-uncased``.
+                  - a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to
+                    our S3, e.g.: ``dbmdz/bert-base-german-cased``.
+                  - a path to a `directory` containing a configuration file saved using the
+                    :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
+                  - a path or url to a saved configuration JSON `file`, e.g.:
+                    ``./my_model_directory/configuration.json``.
+            cache_dir (:obj:`string`, `optional`):
                 Path to a directory in which a downloaded pre-trained model
                 configuration should be cached if the standard cache should not be used.
-
-            kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading.
-                - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values.
-                - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter.
-
-            force_download: (`optional`) boolean, default False:
-                Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
-
-            resume_download: (`optional`) boolean, default False:
+            kwargs (:obj:`Dict[str, any]`, `optional`):
+                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
+                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is
+                controlled by the `return_unused_kwargs` keyword parameter.
+            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
+            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
-
-            proxies: (`optional`) dict, default None:
-                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
+            proxies (:obj:`Dict`, `optional`):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g.:
+                :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`.
                 The proxies are used on each request.
-
             return_unused_kwargs: (`optional`) bool:
+                If False, then this function returns just the final configuration object.
+                If True, then this function returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is a
+                dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e. the part
+                of kwargs which has not been used to update `config` and is otherwise ignored.
 
-                - If False, then this function returns just the final configuration object.
-                - If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part of kwargs which has not been used to update `config` and is otherwise ignored.
+        Returns:
+            :class:`PretrainedConfig`: An instance of a configuration object.
 
         Examples::
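The loading paths and kwarg-override behaviour described in this ``from_pretrained`` docstring, sketched in code (not part of the patch; ``BertConfig`` is again used as an illustrative subclass, and the local directory is hypothetical)::

    from transformers import BertConfig

    # By shortcut name (downloaded and cached), by directory, or by JSON file path.
    config = BertConfig.from_pretrained("bert-base-uncased")
    config = BertConfig.from_pretrained("./my_model_directory/")  # hypothetical dir saved earlier with save_pretrained()

    # kwargs that match configuration attributes override the loaded values;
    # unknown keys are handed back when return_unused_kwargs=True.
    config, unused_kwargs = BertConfig.from_pretrained(
        "bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
    )
    assert config.output_attentions is True
    assert unused_kwargs == {"foo": False}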
@@ -169,9 +184,14 @@ class PretrainedConfig(object):
         for instantiating a Config using `from_dict`.
 
         Parameters:
-            pretrained_config_archive_map: (`optional`) Dict:
-                A map of `shortcut names` to `url`.
-                By default, will use the current class attribute.
+            pretrained_model_name_or_path (:obj:`string`):
+                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
+            pretrained_config_archive_map (:obj:`Dict[str, str]`, `optional`):
+                A map of `shortcut names` to `url`. By default, will use the current class attribute.
+
+        Returns:
+            :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object, and the remaining kwargs.
+
         """
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
@@ -235,8 +255,21 @@ class PretrainedConfig(object):
         return config_dict, kwargs
 
     @classmethod
-    def from_dict(cls, config_dict: Dict, **kwargs):
-        """Constructs a `Config` from a Python dictionary of parameters."""
+    def from_dict(cls, config_dict: Dict, **kwargs) -> 'PretrainedConfig':
+        """
+        Constructs a `Config` from a Python dictionary of parameters.
+
+        Args:
+            config_dict (:obj:`Dict[str, any]`):
+                Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved
+                from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict`
+                method.
+            kwargs (:obj:`Dict[str, any]`):
+                Additional parameters from which to initialize the configuration object.
+
+        Returns:
+            :class:`PretrainedConfig`: An instance of a configuration object.
+        """
         return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
 
         config = cls(**config_dict)
@@ -260,8 +293,18 @@ class PretrainedConfig(object):
         return config
 
     @classmethod
-    def from_json_file(cls, json_file: str):
-        """Constructs a `Config` from the path to a json file of parameters."""
+    def from_json_file(cls, json_file: str) -> 'PretrainedConfig':
+        """
+        Constructs a `Config` from the path to a json file of parameters.
+
+        Args:
+            json_file (:obj:`string`):
+                Path to the JSON file containing the parameters.
+
+        Returns:
+            :class:`PretrainedConfig`: An instance of a configuration object.
+
+        """
         config_dict = cls._dict_from_json_file(json_file)
         return cls(**config_dict)
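A sketch of the two-step path documented above (fetch the raw dictionary with ``get_config_dict``, then build the config from it with ``from_dict``), roughly what :class:`~transformers.AutoConfig` relies on internally; not part of the patch, and the local path is hypothetical::

    from transformers import BertConfig

    # get_config_dict() resolves the identifier and returns (config_dict, remaining_kwargs)
    # without instantiating anything yet.
    config_dict, remaining_kwargs = BertConfig.get_config_dict("bert-base-uncased")

    # from_dict() then builds the configuration object; extra kwargs override loaded values.
    config = BertConfig.from_dict(config_dict, num_labels=3)
    assert config.num_labels == 3

    # from_json_file() covers the purely local case.
    local_config = BertConfig.from_json_file("./my_model_directory/config.json")  # hypothetical path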
@@ -278,17 +321,33 @@ class PretrainedConfig(object):
         return "{} {}".format(self.__class__.__name__, self.to_json_string())
 
     def to_dict(self):
-        """Serializes this instance to a Python dictionary."""
+        """
+        Serializes this instance to a Python dictionary.
+
+        Returns:
+            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+        """
         output = copy.deepcopy(self.__dict__)
         if hasattr(self.__class__, "model_type"):
             output["model_type"] = self.__class__.model_type
         return output
 
     def to_json_string(self):
-        """Serializes this instance to a JSON string."""
+        """
+        Serializes this instance to a JSON string.
+
+        Returns:
+            :obj:`string`: String containing all the attributes that make up this configuration instance in JSON format.
+        """
         return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
 
     def to_json_file(self, json_file_path):
-        """ Save this instance to a json file."""
+        """
+        Save this instance to a json file.
+
+        Args:
+            json_file_path (:obj:`string`):
+                Path to the JSON file in which this configuration instance's parameters will be saved.
+        """
         with open(json_file_path, "w", encoding="utf-8") as writer:
             writer.write(self.to_json_string())
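Finally, a round-trip sketch for the serialization helpers documented in this last hunk (not part of the patch; the output path is hypothetical)::

    from transformers import BertConfig

    config = BertConfig(num_labels=3)

    as_dict = config.to_dict()         # plain Python dict, with the class-level "model_type" added
    as_json = config.to_json_string()  # sorted, indented JSON string ending in a newline
    assert "model_type" in as_dict and as_json.endswith("\n")

    config.to_json_file("./bert_config.json")
    restored = BertConfig.from_json_file("./bert_config.json")
    assert restored.num_labels == config.num_labels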