This commit is contained in:
qubvel 2025-07-01 18:31:07 +00:00
parent 1283877571
commit 03a09c1801


@@ -18,7 +18,7 @@ import copy
import json
import os
import warnings
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from packaging import version
@@ -39,6 +39,10 @@ from .utils import (
from .utils.generic import is_timm_config_dict
if TYPE_CHECKING:
import torch
logger = logging.get_logger(__name__)
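The new `TYPE_CHECKING` import lets the module annotate arguments with `torch.dtype` without importing torch at runtime. A minimal sketch of that pattern (standalone; `resolve_dtype` is an illustrative name, not a function from the library):

```python
from typing import TYPE_CHECKING, Optional, Union

if TYPE_CHECKING:
    # Only evaluated by static type checkers; torch is not imported at runtime.
    import torch


def resolve_dtype(torch_dtype: Optional[Union[str, "torch.dtype"]]) -> Optional["torch.dtype"]:
    """Convert a string such as "float16" to the matching torch.dtype, importing torch lazily."""
    if isinstance(torch_dtype, str):
        import torch  # runtime import happens only when the conversion is actually needed

        return getattr(torch, torch_dtype)
    return torch_dtype
```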
@@ -104,8 +108,9 @@ class PretrainedConfig(PushToHubMixin):
is_encoder_decoder (`bool`, *optional*, defaults to `False`):
Whether the model is used as an encoder/decoder or not.
is_decoder (`bool`, *optional*, defaults to `False`):
Whether to only use the decoder in an encoder-decoder architecture, otherwise it has no effect on decoder-only or encoder-only architectures.
cross_attention_hidden_size** (`bool`, *optional*):
Whether to only use the decoder in an encoder-decoder architecture, otherwise it has no effect on
decoder-only or encoder-only architectures.
cross_attention_hidden_size (`int`, *optional*):
The hidden size of the cross-attention layer in case the model is used as a decoder in an encoder-decoder
setting and the cross-attention hidden dimension differs from `self.config.hidden_size`.
add_cross_attention (`bool`, *optional*, defaults to `False`):
@@ -135,7 +140,8 @@ class PretrainedConfig(PushToHubMixin):
or PyTorch) checkpoint.
id2label (`dict[int, str]`, *optional*):
A map from index (for instance prediction index, or target index) to label.
label2id (`dict[str, int]`, *optional*): A map from label to index for the model.
label2id (`dict[str, int]`, *optional*):
A map from label to index for the model.
num_labels (`int`, *optional*):
Number of labels to use in the last layer added to the model, typically for a classification task.
task_specific_params (`dict[str, Any]`, *optional*):
@@ -151,12 +157,16 @@ class PretrainedConfig(PushToHubMixin):
model by default).
prefix (`str`, *optional*):
A specific prompt that should be added at the beginning of each text before calling the model.
bos_token_id (`int`, *optional*): The id of the _beginning-of-stream_ token.
pad_token_id (`int`, *optional*): The id of the _padding_ token.
eos_token_id (`int`, *optional*): The id of the _end-of-stream_ token.
bos_token_id (`int`, *optional*):
The id of the _beginning-of-stream_ token.
pad_token_id (`int`, *optional*):
The id of the _padding_ token.
eos_token_id (`int`, *optional*):
The id of the _end-of-stream_ token.
decoder_start_token_id (`int`, *optional*):
If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token.
sep_token_id (`int`, *optional*): The id of the _separation_ token.
sep_token_id (`int`, *optional*):
The id of the _separation_ token.
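The token-id parameters above are stored as plain attributes and serialized with the rest of the config. A quick sketch against the public `PretrainedConfig` API (the id values are arbitrary placeholders):

```python
from transformers import PretrainedConfig

config = PretrainedConfig(bos_token_id=0, eos_token_id=1, pad_token_id=2)
print(config.bos_token_id, config.eos_token_id, config.pad_token_id)  # 0 1 2
print("pad_token_id" in config.to_dict())                             # True
```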
> PyTorch specific parameters
@@ -175,18 +185,6 @@ class PretrainedConfig(PushToHubMixin):
This attribute is currently not used during model loading, but this may change in future versions. We can
already prepare for that by saving the dtype with save_pretrained.
> TensorFlow specific parameters
use_bfloat16 (`bool`, *optional*, defaults to `False`):
Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models).
tf_legacy_loss (`bool`, *optional*, defaults to `False`):
Whether the model should use legacy TensorFlow losses. Legacy losses have variable output shapes and may
not be XLA-compatible. This option is here for backward compatibility and will be removed in Transformers
v5.
loss_type (`str`, *optional*):
The type of loss that the model should use. It should be in `LOSS_MAPPING`'s keys, otherwise the loss will
be automatically inferred from the model architecture.
"""
model_type: str = ""
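The `id2label`/`label2id`/`num_labels` parameters documented above interact as the code below shows: `num_labels` is derived from `id2label`, and JSON-style string keys are normalized to ints. A short sketch using the public `PretrainedConfig` API (the label names are arbitrary):

```python
from transformers import PretrainedConfig

# No maps given: default LABEL_i entries are generated (num_labels defaults to 2).
config = PretrainedConfig()
print(config.id2label)    # {0: 'LABEL_0', 1: 'LABEL_1'}
print(config.num_labels)  # 2, computed from len(id2label)

# Keys coming from JSON are strings; they are converted to ints at construction time.
config = PretrainedConfig(
    id2label={"0": "negative", "1": "positive"},
    label2id={"negative": 0, "positive": 1},
)
print(config.id2label)    # {0: 'negative', 1: 'positive'}
```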
@@ -208,93 +206,117 @@ class PretrainedConfig(PushToHubMixin):
key = super().__getattribute__("attribute_map")[key]
return super().__getattribute__(key)
def __init__(self, **kwargs):
# Attributes with defaults
self.return_dict = kwargs.pop("return_dict", True)
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
self._output_attentions = kwargs.pop("output_attentions", False)
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
self.torch_dtype = kwargs.pop("torch_dtype", None) # Only used by PyTorch models
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False) # Only used by TensorFlow models
self.pruned_heads = kwargs.pop("pruned_heads", {})
self.tie_word_embeddings = kwargs.pop(
"tie_word_embeddings", True
) # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
def __init__(
self,
*,
# All models common arguments
output_hidden_states: bool = False,
output_attentions: bool = False,
return_dict: bool = True,
torchscript: bool = False,
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
# Common arguments
pruned_heads: Optional[dict[int, list[int]]] = None,
tie_word_embeddings: bool = True,
chunk_size_feed_forward: int = 0,
is_encoder_decoder: bool = False,
is_decoder: bool = False,
cross_attention_hidden_size: Optional[int] = None,
add_cross_attention: bool = False,
tie_encoder_decoder: bool = False,
# Fine-tuning task arguments
architectures: Optional[list[str]] = None,
finetuning_task: Optional[str] = None,
id2label: Optional[dict[int, str]] = None,
label2id: Optional[dict[str, int]] = None,
num_labels: Optional[int] = None,
task_specific_params: Optional[dict[str, Any]] = None,
problem_type: Optional[str] = None,
# Tokenizer kwargs
tokenizer_class: Optional[str] = None,
prefix: Optional[str] = None,
bos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
sep_token_id: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
**kwargs,
):
# Validation for some arguments
if label2id is not None and not isinstance(label2id, dict):
raise ValueError("Argument label2id should be a dictionary.")
if id2label is not None and not isinstance(id2label, dict):
raise ValueError("Argument id2label should be a dictionary.")
if num_labels is not None and id2label is not None and len(id2label) != num_labels:
logger.warning(
f"You passed `num_labels={num_labels}` which is incompatible to "
f"the `id2label` map of length `{len(id2label)}`."
)
if problem_type is not None and problem_type not in (
"regression",
"single_label_classification",
"multi_label_classification",
):
raise ValueError(
f"The config parameter `problem_type` was not understood: received {problem_type} "
"but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
)
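The up-front validation added here rejects malformed arguments at construction time; for example (behavior as implemented in this diff):

```python
from transformers import PretrainedConfig

try:
    PretrainedConfig(problem_type="multilabel")  # not one of the three accepted values
except ValueError as err:
    print(err)  # "The config parameter `problem_type` was not understood: ..."

try:
    PretrainedConfig(label2id=["negative", "positive"])  # must be a dict, not a list
except ValueError as err:
    print(err)  # "Argument label2id should be a dictionary."
```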
if torch_dtype is not None and isinstance(torch_dtype, str) and is_torch_available():
# we will start using self.torch_dtype in v5, but to be consistent with
# from_pretrained's torch_dtype arg convert it to an actual torch.dtype object
import torch
# Is decoder is used in encoder-decoder models to differentiate encoder from decoder
self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
self.is_decoder = kwargs.pop("is_decoder", False)
self.cross_attention_hidden_size = kwargs.pop("cross_attention_hidden_size", None)
self.add_cross_attention = kwargs.pop("add_cross_attention", False)
self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
torch_dtype = getattr(torch, torch_dtype)
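The `torch_dtype` handling converts a string name into the real `torch.dtype` object at construction time, guarded by `is_torch_available()`. A small sketch, assuming a working torch install:

```python
import torch
from transformers import PretrainedConfig

config = PretrainedConfig(torch_dtype="float16")
print(config.torch_dtype)                   # torch.float16
print(config.torch_dtype is torch.float16)  # True
# Without torch installed, the string would be kept as-is, since the
# conversion only runs when is_torch_available() is True.
```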
# Attributes common for all models
self.return_dict = return_dict
self.output_hidden_states = output_hidden_states
self.torchscript = torchscript
self.torch_dtype = torch_dtype
self._output_attentions = output_attentions # has public property
# Less common kwargs, only used by some models
self.pruned_heads = pruned_heads if pruned_heads is not None else {}
self.tie_word_embeddings = tie_word_embeddings
self.chunk_size_feed_forward = chunk_size_feed_forward
# Encoder-decoder models attributes
self.is_encoder_decoder = is_encoder_decoder
self.is_decoder = is_decoder # used in encoder-decoder models to differentiate encoder from decoder
self.cross_attention_hidden_size = cross_attention_hidden_size
self.add_cross_attention = add_cross_attention
self.tie_encoder_decoder = tie_encoder_decoder
# Fine-tuning task attributes
self.architectures = architectures
self.finetuning_task = finetuning_task
self.id2label = id2label
self.label2id = label2id
self.task_specific_params = task_specific_params
self.problem_type = problem_type
if self.id2label is None:
self._create_id_label_maps(num_labels if num_labels is not None else 2)
else:
# Keys are always strings in JSON so convert ids to int here.
self.id2label = {int(key): value for key, value in self.id2label.items()}
# Tokenizer attributes
self.tokenizer_class = tokenizer_class
self.prefix = prefix
self.bos_token_id = bos_token_id
self.pad_token_id = pad_token_id
self.eos_token_id = eos_token_id
self.sep_token_id = sep_token_id
self.decoder_start_token_id = decoder_start_token_id
# Retrocompatibility: Parameters for sequence generation. While we will keep the ability to load these
# parameters, saving them will be deprecated. In a distant future, we won't need to load them.
for parameter_name, default_value in self._get_global_generation_defaults().items():
setattr(self, parameter_name, kwargs.pop(parameter_name, default_value))
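For backward compatibility, any legacy generation parameter left in the remaining kwargs is absorbed onto the config with its global default. A generic sketch of that pattern (the defaults shown are illustrative only, not the exact contents of `_get_global_generation_defaults()`):

```python
def absorb_generation_kwargs(config, kwargs, defaults):
    """Pop each known generation parameter from kwargs and set it on the config,
    falling back to its global default when the caller did not pass it."""
    for parameter_name, default_value in defaults.items():
        setattr(config, parameter_name, kwargs.pop(parameter_name, default_value))


class DummyConfig:
    pass


cfg, leftover = DummyConfig(), {"num_beams": 4, "custom_field": 1}
absorb_generation_kwargs(cfg, leftover, {"max_length": 20, "num_beams": 1})
print(cfg.max_length, cfg.num_beams, leftover)  # 20 4 {'custom_field': 1}
```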
# Fine-tuning task arguments
self.architectures = kwargs.pop("architectures", None)
self.finetuning_task = kwargs.pop("finetuning_task", None)
self.id2label = kwargs.pop("id2label", None)
self.label2id = kwargs.pop("label2id", None)
if self.label2id is not None and not isinstance(self.label2id, dict):
raise ValueError("Argument label2id should be a dictionary.")
if self.id2label is not None:
if not isinstance(self.id2label, dict):
raise ValueError("Argument id2label should be a dictionary.")
num_labels = kwargs.pop("num_labels", None)
if num_labels is not None and len(self.id2label) != num_labels:
logger.warning(
f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
f"{self.id2label}. The number of labels will be overwritten to {self.num_labels}."
)
self.id2label = {int(key): value for key, value in self.id2label.items()}
# Keys are always strings in JSON so convert ids to int here.
else:
self.num_labels = kwargs.pop("num_labels", 2)
if self.torch_dtype is not None and isinstance(self.torch_dtype, str):
# we will start using self.torch_dtype in v5, but to be consistent with
# from_pretrained's torch_dtype arg convert it to an actual torch.dtype object
if is_torch_available():
import torch
self.torch_dtype = getattr(torch, self.torch_dtype)
# Tokenizer arguments TODO: eventually tokenizer and models should share the same config
self.tokenizer_class = kwargs.pop("tokenizer_class", None)
self.prefix = kwargs.pop("prefix", None)
self.bos_token_id = kwargs.pop("bos_token_id", None)
self.pad_token_id = kwargs.pop("pad_token_id", None)
self.eos_token_id = kwargs.pop("eos_token_id", None)
self.sep_token_id = kwargs.pop("sep_token_id", None)
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
# task specific arguments
self.task_specific_params = kwargs.pop("task_specific_params", None)
# regression / multi-label classification
self.problem_type = kwargs.pop("problem_type", None)
allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification")
if self.problem_type is not None and self.problem_type not in allowed_problem_types:
raise ValueError(
f"The config parameter `problem_type` was not understood: received {self.problem_type} "
"but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
)
# TPU arguments
if kwargs.pop("xla_device", None) is not None:
logger.warning(
"The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can "
"safely remove it from your `config.json` file."
)
# Name or path to the pretrained checkpoint
self._name_or_path = str(kwargs.pop("name_or_path", ""))
# Config hash
self._commit_hash = kwargs.pop("_commit_hash", None)
# Attention implementation to use, if relevant.
@@ -320,9 +342,17 @@ class PretrainedConfig(PushToHubMixin):
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
# TODO: remove later, deprecated arguments for TF models
self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False)
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
def _create_id_label_maps(self, num_labels: int):
self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
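The new `_create_id_label_maps` helper centralizes the default-map construction that was previously duplicated between `__init__` and the `num_labels` setter. Its effect, reproduced in plain Python for three labels:

```python
num_labels = 3
id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
label2id = dict(zip(id2label.values(), id2label.keys()))
print(id2label)  # {0: 'LABEL_0', 1: 'LABEL_1', 2: 'LABEL_2'}
print(label2id)  # {'LABEL_0': 0, 'LABEL_1': 1, 'LABEL_2': 2}
```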
@property
def name_or_path(self) -> str:
return getattr(self, "_name_or_path", None)
return self._name_or_path
@name_or_path.setter
def name_or_path(self, value):
@@ -361,9 +391,10 @@ class PretrainedConfig(PushToHubMixin):
@num_labels.setter
def num_labels(self, num_labels: int):
if not hasattr(self, "id2label") or self.id2label is None or len(self.id2label) != num_labels:
self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
# we do not store a `num_labels` attribute in the config, but instead
# compute it from the length of the `id2label` map
if self.id2label is None or self.num_labels != num_labels:
self._create_id_label_maps(num_labels)
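With this refactor the setter only rebuilds the maps when the requested size differs from the current `id2label`, so an explicit label map of the right size is preserved. A hedged usage sketch:

```python
from transformers import PretrainedConfig

config = PretrainedConfig(id2label={0: "negative", 1: "positive"})
config.num_labels = 2   # same size: the custom labels are kept
print(config.id2label)  # {0: 'negative', 1: 'positive'}

config.num_labels = 3   # different size: default LABEL_i maps are regenerated
print(config.id2label)  # {0: 'LABEL_0', 1: 'LABEL_1', 2: 'LABEL_2'}
```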
@property
def _attn_implementation(self):