mirror of https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
cleanup
This commit is contained in:
parent 1283877571
commit 03a09c1801
@@ -18,7 +18,7 @@ import copy
 import json
 import os
 import warnings
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

 from packaging import version

@@ -39,6 +39,10 @@ from .utils import (
 from .utils.generic import is_timm_config_dict


+if TYPE_CHECKING:
+    import torch
+
+
 logger = logging.get_logger(__name__)

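The `if TYPE_CHECKING:` block added above exists so that annotations such as `"torch.dtype"` can be written without importing torch at runtime. A minimal sketch of the same pattern, with an illustrative helper name:

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated only by type checkers, so torch is not a runtime dependency.
    import torch


def describe_dtype(dtype: Optional["torch.dtype"] = None) -> str:
    # Hypothetical helper, only to show the string-form annotation in use.
    return str(dtype) if dtype is not None else "unset"
```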
@@ -104,8 +108,9 @@ class PretrainedConfig(PushToHubMixin):
         is_encoder_decoder (`bool`, *optional*, defaults to `False`):
             Whether the model is used as an encoder/decoder or not.
         is_decoder (`bool`, *optional*, defaults to `False`):
-            Whether to only use the decoder in an encoder-decoder architecture, otherwise it has no effect on decoder-only or encoder-only architectures.
-        cross_attention_hidden_size** (`bool`, *optional*):
+            Whether to only use the decoder in an encoder-decoder architecture, otherwise it has no effect on
+            decoder-only or encoder-only architectures.
+        cross_attention_hidden_size (`bool`, *optional*):
             The hidden size of the cross-attention layer in case the model is used as a decoder in an encoder-decoder
             setting and the cross-attention hidden dimension differs from `self.config.hidden_size`.
         add_cross_attention (`bool`, *optional*, defaults to `False`):
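These flags usually matter when a config is reused as the decoder half of an encoder-decoder setup. A hedged example of setting them (the concrete config class and sizes are illustrative, not part of this commit):

```python
from transformers import BertConfig

# Mark the config as a decoder and enable cross-attention layers.
decoder_config = BertConfig(
    is_decoder=True,
    add_cross_attention=True,
    cross_attention_hidden_size=512,  # only needed if the encoder hidden size differs
)
print(decoder_config.is_decoder, decoder_config.add_cross_attention)  # True True
```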
@@ -135,7 +140,8 @@ class PretrainedConfig(PushToHubMixin):
             or PyTorch) checkpoint.
         id2label (`dict[int, str]`, *optional*):
             A map from index (for instance prediction index, or target index) to label.
-        label2id (`dict[str, int]`, *optional*): A map from label to index for the model.
+        label2id (`dict[str, int]`, *optional*):
+            A map from label to index for the model.
         num_labels (`int`, *optional*):
             Number of labels to use in the last layer added to the model, typically for a classification task.
         task_specific_params (`dict[str, Any]`, *optional*):
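As the reflowed docstring states, `id2label` and `label2id` are plain dicts and `num_labels` is derived from them. For example, using the base class directly:

```python
from transformers import PretrainedConfig

config = PretrainedConfig(
    id2label={0: "negative", 1: "positive"},
    label2id={"negative": 0, "positive": 1},
)
print(config.num_labels)  # 2, computed from the length of id2label
```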
@@ -151,12 +157,16 @@ class PretrainedConfig(PushToHubMixin):
             model by default).
         prefix (`str`, *optional*):
             A specific prompt that should be added at the beginning of each text before calling the model.
-        bos_token_id (`int`, *optional*): The id of the _beginning-of-stream_ token.
-        pad_token_id (`int`, *optional*): The id of the _padding_ token.
-        eos_token_id (`int`, *optional*): The id of the _end-of-stream_ token.
+        bos_token_id (`int`, *optional*):
+            The id of the _beginning-of-stream_ token.
+        pad_token_id (`int`, *optional*):
+            The id of the _padding_ token.
+        eos_token_id (`int`, *optional*):
+            The id of the _end-of-stream_ token.
         decoder_start_token_id (`int`, *optional*):
             If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token.
-        sep_token_id (`int`, *optional*): The id of the _separation_ token.
+        sep_token_id (`int`, *optional*):
+            The id of the _separation_ token.

         > PyTorch specific parameters

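The token-id fields documented here are plain optional integers; for instance (values illustrative):

```python
from transformers import PretrainedConfig

config = PretrainedConfig(
    bos_token_id=0,
    pad_token_id=1,
    eos_token_id=2,
    decoder_start_token_id=2,  # seq2seq models often reuse bos/eos here
)
print(config.eos_token_id)  # 2
```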
@@ -175,18 +185,6 @@ class PretrainedConfig(PushToHubMixin):

             This attribute is currently not being used during model loading time, but this may change in the future
             versions. But we can already start preparing for the future by saving the dtype with save_pretrained.
-
-        > TensorFlow specific parameters
-
-        use_bfloat16 (`bool`, *optional*, defaults to `False`):
-            Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models).
-        tf_legacy_loss (`bool`, *optional*, defaults to `False`):
-            Whether the model should use legacy TensorFlow losses. Legacy losses have variable output shapes and may
-            not be XLA-compatible. This option is here for backward compatibility and will be removed in Transformers
-            v5.
-        loss_type (`str`, *optional*):
-            The type of loss that the model should use. It should be in `LOSS_MAPPING`'s keys, otherwise the loss will
-            be automatically inferred from the model architecture.
     """

     model_type: str = ""
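Although `use_bfloat16`, `tf_legacy_loss` and `loss_type` disappear from the docstring here, the first two are still accepted: a later hunk in this commit pops them from `**kwargs` and stores them as attributes, so existing configs keep loading. A quick check, assuming the post-change code:

```python
from transformers import PretrainedConfig

config = PretrainedConfig(tf_legacy_loss=True, use_bfloat16=False)
print(config.tf_legacy_loss, config.use_bfloat16)  # True False
```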
@@ -208,93 +206,117 @@ class PretrainedConfig(PushToHubMixin):
             key = super().__getattribute__("attribute_map")[key]
         return super().__getattribute__(key)

-    def __init__(self, **kwargs):
-        # Attributes with defaults
-        self.return_dict = kwargs.pop("return_dict", True)
-        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
-        self._output_attentions = kwargs.pop("output_attentions", False)
-        self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
-        self.torch_dtype = kwargs.pop("torch_dtype", None)  # Only used by PyTorch models
-        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
-        self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False)  # Only used by TensorFlow models
-        self.pruned_heads = kwargs.pop("pruned_heads", {})
-        self.tie_word_embeddings = kwargs.pop(
-            "tie_word_embeddings", True
-        )  # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
-        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
-
-        # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
-        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
-        self.is_decoder = kwargs.pop("is_decoder", False)
-        self.cross_attention_hidden_size = kwargs.pop("cross_attention_hidden_size", None)
-        self.add_cross_attention = kwargs.pop("add_cross_attention", False)
-        self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
+    def __init__(
+        self,
+        *,
+        # All models common arguments
+        output_hidden_states: bool = False,
+        output_attentions: bool = False,
+        return_dict: bool = True,
+        torchscript: bool = False,
+        torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
+        # Common arguments
+        pruned_heads: Optional[dict[int, list[int]]] = None,
+        tie_word_embeddings: bool = True,
+        chunk_size_feed_forward: int = 0,
+        is_encoder_decoder: bool = False,
+        is_decoder: bool = False,
+        cross_attention_hidden_size: Optional[int] = None,
+        add_cross_attention: bool = False,
+        tie_encoder_decoder: bool = False,
+        # Fine-tuning task arguments
+        architectures: Optional[list[str]] = None,
+        finetuning_task: Optional[str] = None,
+        id2label: Optional[dict[int, str]] = None,
+        label2id: Optional[dict[str, int]] = None,
+        num_labels: Optional[int] = None,
+        task_specific_params: Optional[dict[str, Any]] = None,
+        problem_type: Optional[str] = None,
+        # Tokenizer kwargs
+        tokenizer_class: Optional[str] = None,
+        prefix: Optional[str] = None,
+        bos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        sep_token_id: Optional[int] = None,
+        decoder_start_token_id: Optional[int] = None,
+        **kwargs,
+    ):
+        # Validation for some arguments
+        if label2id is not None and not isinstance(label2id, dict):
+            raise ValueError("Argument label2id should be a dictionary.")
+        if id2label is not None and not isinstance(id2label, dict):
+            raise ValueError("Argument id2label should be a dictionary.")
+        if num_labels is not None and id2label is not None and len(id2label) != num_labels:
+            logger.warning(
+                f"You passed `num_labels={num_labels}` which is incompatible to "
+                f"the `id2label` map of length `{len(id2label)}`."
+            )
+        if problem_type is not None and problem_type not in (
+            "regression",
+            "single_label_classification",
+            "multi_label_classification",
+        ):
+            raise ValueError(
+                f"The config parameter `problem_type` was not understood: received {problem_type} "
+                "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
+            )
+        if torch_dtype is not None and isinstance(torch_dtype, str) and is_torch_available():
+            # we will start using self.torch_dtype in v5, but to be consistent with
+            # from_pretrained's torch_dtype arg convert it to an actual torch.dtype object
+            import torch
+
+            torch_dtype = getattr(torch, torch_dtype)
+
+        # Attributes common for all models
+        self.return_dict = return_dict
+        self.output_hidden_states = output_hidden_states
+        self.torchscript = torchscript
+        self.torch_dtype = torch_dtype
+        self._output_attentions = output_attentions  # has public property
+
+        # Less common kwargs, only used by some models
+        self.pruned_heads = pruned_heads if pruned_heads is not None else {}
+        self.tie_word_embeddings = tie_word_embeddings
+        self.chunk_size_feed_forward = chunk_size_feed_forward
+
+        # Encoder-decoder models attributes
+        self.is_encoder_decoder = is_encoder_decoder
+        self.is_decoder = is_decoder  # used in encoder-decoder models to differentiate encoder from decoder
+        self.cross_attention_hidden_size = cross_attention_hidden_size
+        self.add_cross_attention = add_cross_attention
+        self.tie_encoder_decoder = tie_encoder_decoder
+
+        # Fine-tuning task attributes
+        self.architectures = architectures
+        self.finetuning_task = finetuning_task
+        self.id2label = id2label
+        self.label2id = label2id
+        self.task_specific_params = task_specific_params
+        self.problem_type = problem_type
+
+        if self.id2label is None:
+            self._create_id_label_maps(num_labels if num_labels is not None else 2)
+        else:
+            # Keys are always strings in JSON so convert ids to int here.
+            self.id2label = {int(key): value for key, value in self.id2label.items()}
+
+        # Tokenizer attributes
+        self.tokenizer_class = tokenizer_class
+        self.prefix = prefix
+        self.bos_token_id = bos_token_id
+        self.pad_token_id = pad_token_id
+        self.eos_token_id = eos_token_id
+        self.sep_token_id = sep_token_id
+        self.decoder_start_token_id = decoder_start_token_id

         # Retrocompatibility: Parameters for sequence generation. While we will keep the ability to load these
         # parameters, saving them will be deprecated. In a distant future, we won't need to load them.
         for parameter_name, default_value in self._get_global_generation_defaults().items():
             setattr(self, parameter_name, kwargs.pop(parameter_name, default_value))

-        # Fine-tuning task arguments
-        self.architectures = kwargs.pop("architectures", None)
-        self.finetuning_task = kwargs.pop("finetuning_task", None)
-        self.id2label = kwargs.pop("id2label", None)
-        self.label2id = kwargs.pop("label2id", None)
-        if self.label2id is not None and not isinstance(self.label2id, dict):
-            raise ValueError("Argument label2id should be a dictionary.")
-        if self.id2label is not None:
-            if not isinstance(self.id2label, dict):
-                raise ValueError("Argument id2label should be a dictionary.")
-            num_labels = kwargs.pop("num_labels", None)
-            if num_labels is not None and len(self.id2label) != num_labels:
-                logger.warning(
-                    f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
-                    f"{self.id2label}. The number of labels will be overwritten to {self.num_labels}."
-                )
-            self.id2label = {int(key): value for key, value in self.id2label.items()}
-            # Keys are always strings in JSON so convert ids to int here.
-        else:
-            self.num_labels = kwargs.pop("num_labels", 2)
-
-        if self.torch_dtype is not None and isinstance(self.torch_dtype, str):
-            # we will start using self.torch_dtype in v5, but to be consistent with
-            # from_pretrained's torch_dtype arg convert it to an actual torch.dtype object
-            if is_torch_available():
-                import torch
-
-                self.torch_dtype = getattr(torch, self.torch_dtype)
-
-        # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
-        self.tokenizer_class = kwargs.pop("tokenizer_class", None)
-        self.prefix = kwargs.pop("prefix", None)
-        self.bos_token_id = kwargs.pop("bos_token_id", None)
-        self.pad_token_id = kwargs.pop("pad_token_id", None)
-        self.eos_token_id = kwargs.pop("eos_token_id", None)
-        self.sep_token_id = kwargs.pop("sep_token_id", None)
-        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
-
-        # task specific arguments
-        self.task_specific_params = kwargs.pop("task_specific_params", None)
-
-        # regression / multi-label classification
-        self.problem_type = kwargs.pop("problem_type", None)
-        allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification")
-        if self.problem_type is not None and self.problem_type not in allowed_problem_types:
-            raise ValueError(
-                f"The config parameter `problem_type` was not understood: received {self.problem_type} "
-                "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
-            )
-
-        # TPU arguments
-        if kwargs.pop("xla_device", None) is not None:
-            logger.warning(
-                "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can "
-                "safely remove it from your `config.json` file."
-            )

         # Name or path to the pretrained checkpoint
         self._name_or_path = str(kwargs.pop("name_or_path", ""))
         # Config hash
         self._commit_hash = kwargs.pop("_commit_hash", None)

         # Attention implementation to use, if relevant.
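With this rewrite the common arguments become explicit keyword-only parameters that are validated before anything is assigned, while unknown keys still flow through `**kwargs` and are set as attributes. A hedged sketch of the resulting behavior, assuming the post-change code above and an installed torch:

```python
from transformers import PretrainedConfig

# Explicit arguments are validated; extra keys like hidden_size pass through **kwargs.
config = PretrainedConfig(hidden_size=128, torch_dtype="float16", num_labels=3)
print(config.torch_dtype)  # torch.float16 (the string form is converted when torch is available)
print(config.id2label)     # {0: 'LABEL_0', 1: 'LABEL_1', 2: 'LABEL_2'}

# Invalid values now fail fast.
try:
    PretrainedConfig(problem_type="ranking")
except ValueError as err:
    print(err)
```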
@@ -320,9 +342,17 @@ class PretrainedConfig(PushToHubMixin):
                 logger.error(f"Can't set {key} with value {value} for {self}")
                 raise err

+        # TODO: remove later, deprecated arguments for TF models
+        self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False)
+        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
+
+    def _create_id_label_maps(self, num_labels: int):
+        self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
+        self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
+
     @property
     def name_or_path(self) -> str:
-        return getattr(self, "_name_or_path", None)
+        return self._name_or_path

     @name_or_path.setter
     def name_or_path(self, value):
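The new `_create_id_label_maps` helper centralizes the default label maps that were previously built inline in `__init__` and in the `num_labels` setter. It is equivalent to:

```python
num_labels = 2
id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
label2id = dict(zip(id2label.values(), id2label.keys()))
print(id2label)  # {0: 'LABEL_0', 1: 'LABEL_1'}
print(label2id)  # {'LABEL_0': 0, 'LABEL_1': 1}
```

The `name_or_path` property also drops its `getattr` fallback, presumably because `_name_or_path` is always set in `__init__`.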
@@ -361,9 +391,10 @@ class PretrainedConfig(PushToHubMixin):

     @num_labels.setter
     def num_labels(self, num_labels: int):
-        if not hasattr(self, "id2label") or self.id2label is None or len(self.id2label) != num_labels:
-            self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
-            self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
+        # we do not store `num_labels` attribute in config, but instead
+        # compute it based on the length of the `id2label` map
+        if self.id2label is None or self.num_labels != num_labels:
+            self._create_id_label_maps(num_labels)

     @property
     def _attn_implementation(self):
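The reworked setter rebuilds the maps only when the requested size actually differs, since `num_labels` itself is computed from `len(id2label)`. Assuming the post-change code:

```python
from transformers import PretrainedConfig

config = PretrainedConfig(id2label={0: "cat", 1: "dog"})
config.num_labels = 2   # size already matches, custom labels are kept
print(config.id2label)  # {0: 'cat', 1: 'dog'}

config.num_labels = 3   # size differs, maps are regenerated with default names
print(config.id2label)  # {0: 'LABEL_0', 1: 'LABEL_1', 2: 'LABEL_2'}
```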
Loading…
Reference in New Issue
Block a user