diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 13fb9418224..ce4cdab8b86 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -29,7 +29,6 @@ import warnings
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import contextmanager
-from dataclasses import dataclass
 from enum import Enum
 from functools import partial, wraps
 from threading import Thread
@@ -41,7 +40,6 @@ from huggingface_hub import split_torch_state_dict_into_shards
 from packaging import version
 from torch import Tensor, nn
 from torch.distributions import constraints
-from torch.nn import CrossEntropyLoss, Identity
 from torch.utils.checkpoint import checkpoint
 
 from transformers.utils import is_torchao_available
@@ -50,7 +48,6 @@ from transformers.utils import is_torchao_available
 if is_torchao_available():
     from torchao.quantization import Int4WeightOnlyConfig
 
-from .activations import get_activation
 from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
 from .generation import CompileConfig, GenerationConfig
@@ -98,7 +95,6 @@ from .utils import (
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     ContextManagers,
-    ModelOutput,
     PushToHubMixin,
     cached_file,
     check_torch_load_is_safe,
@@ -123,7 +119,6 @@ from .utils import (
     is_torch_xla_available,
     is_torch_xpu_available,
     logging,
-    replace_return_docstrings,
     strtobool,
 )
 from .utils.generic import GeneralInterface
@@ -5624,453 +5619,6 @@ if PreTrainedModel.push_to_hub.__doc__ is not None:
     )
 
 
-class PoolerStartLogits(nn.Module):
-    """
-    Compute SQuAD start logits from sequence hidden states.
-
-    Args:
-        config ([`PretrainedConfig`]):
-            The config used by the model, will be used to grab the `hidden_size` of the model.
-    """
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-        self.dense = nn.Linear(config.hidden_size, 1)
-        logger.warning_once(
-            "[DEPRECATION WARNING] `PoolerStartLogits` is deprecated and will be removed in v4.53. "
-            "Please use model-specific class, e.g. `XLMPoolerStartLogits`."
-        )
-
-    def forward(
-        self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
-    ) -> torch.FloatTensor:
-        """
-        Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
-                The final hidden states of the model.
-            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
-                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
-                should be masked.
-
-        Returns:
-            `torch.FloatTensor`: The start logits for SQuAD.
-        """
-        x = self.dense(hidden_states).squeeze(-1)
-
-        if p_mask is not None:
-            if get_parameter_dtype(self) == torch.float16:
-                x = x * (1 - p_mask) - 65500 * p_mask
-            else:
-                x = x * (1 - p_mask) - 1e30 * p_mask
-
-        return x
-
-
-class PoolerEndLogits(nn.Module):
-    """
-    Compute SQuAD end logits from sequence hidden states.
-
-    Args:
-        config ([`PretrainedConfig`]):
-            The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
-            to use.
-    """
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
-        self.activation = nn.Tanh()
-        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.dense_1 = nn.Linear(config.hidden_size, 1)
-        logger.warning_once(
-            "[DEPRECATION WARNING] `PoolerEndLogits` is deprecated and will be removed in v4.53. "
-            "Please use model-specific class, e.g. `XLMPoolerEndLogits`."
-        )
-
-    def forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        start_states: Optional[torch.FloatTensor] = None,
-        start_positions: Optional[torch.LongTensor] = None,
-        p_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
-        """
-        Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
-                The final hidden states of the model.
-            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
-                The hidden states of the first tokens for the labeled span.
-            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-                The position of the first token for the labeled span.
-            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
-                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
-                should be masked.
-
-        <Tip>
-
-        One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
-        `start_states`.
-
-        </Tip>
-
-        Returns:
-            `torch.FloatTensor`: The end logits for SQuAD.
-        """
-        assert start_states is not None or start_positions is not None, (
-            "One of start_states, start_positions should be not None"
-        )
-        if start_positions is not None:
-            slen, hsz = hidden_states.shape[-2:]
-            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
-            start_states = hidden_states.gather(-2, start_positions)  # shape (bsz, 1, hsz)
-            start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)
-
-        x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
-        x = self.activation(x)
-        x = self.LayerNorm(x)
-        x = self.dense_1(x).squeeze(-1)
-
-        if p_mask is not None:
-            if get_parameter_dtype(self) == torch.float16:
-                x = x * (1 - p_mask) - 65500 * p_mask
-            else:
-                x = x * (1 - p_mask) - 1e30 * p_mask
-
-        return x
-
-
-class PoolerAnswerClass(nn.Module):
-    """
-    Compute SQuAD 2.0 answer class from classification and start tokens hidden states.
-
-    Args:
-        config ([`PretrainedConfig`]):
-            The config used by the model, will be used to grab the `hidden_size` of the model.
-    """
-
-    def __init__(self, config):
-        super().__init__()
-        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
-        self.activation = nn.Tanh()
-        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
-        logger.warning_once(
-            "[DEPRECATION WARNING] `PoolerAnswerClass` is deprecated and will be removed in v4.53. "
-            "Please use model-specific class, e.g. `XLMPoolerAnswerClass`."
-        )
-
-    def forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        start_states: Optional[torch.FloatTensor] = None,
-        start_positions: Optional[torch.LongTensor] = None,
-        cls_index: Optional[torch.LongTensor] = None,
-    ) -> torch.FloatTensor:
-        """
-        Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
-                The final hidden states of the model.
-            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
-                The hidden states of the first tokens for the labeled span.
-            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-                The position of the first token for the labeled span.
-            cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-                Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
-
-        <Tip>
-
-        One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
-        `start_states`.
-
-        </Tip>
-
-        Returns:
-            `torch.FloatTensor`: The SQuAD 2.0 answer class.
-        """
-        # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
-        hsz = hidden_states.shape[-1]
-        assert start_states is not None or start_positions is not None, (
-            "One of start_states, start_positions should be not None"
-        )
-        if start_positions is not None:
-            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
-            start_states = hidden_states.gather(-2, start_positions).squeeze(-2)  # shape (bsz, hsz)
-
-        if cls_index is not None:
-            cls_index = cls_index[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
-            cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, hsz)
-        else:
-            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)
-
-        x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
-        x = self.activation(x)
-        x = self.dense_1(x).squeeze(-1)
-
-        return x
-
-
-@dataclass
-class SquadHeadOutput(ModelOutput):
-    """
-    Base class for outputs of question answering models using a [`~modeling_utils.SQuADHead`].
-
-    Args:
-        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
-            Classification loss as the sum of start token, end token (and is_impossible if provided) classification
-            losses.
-        start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
-            Log probabilities for the top config.start_n_top start token possibilities (beam-search).
-        start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
-            Indices for the top config.start_n_top start token possibilities (beam-search).
-        end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
-            Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
-            (beam-search).
-        end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
-            Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
-        cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
-            Log probabilities for the `is_impossible` label of the answers.
-
-    """
-
-    loss: Optional[torch.FloatTensor] = None
-    start_top_log_probs: Optional[torch.FloatTensor] = None
-    start_top_index: Optional[torch.LongTensor] = None
-    end_top_log_probs: Optional[torch.FloatTensor] = None
-    end_top_index: Optional[torch.LongTensor] = None
-    cls_logits: Optional[torch.FloatTensor] = None
-
-    def __post_init__(self):
-        logger.warning_once(
-            "[DEPRECATION WARNING] `SquadHeadOutput` is deprecated and will be removed in v4.53. "
-            "Please use model-specific class, e.g. `XLMSquadHeadOutput`."
-        )
-
-
-class SQuADHead(nn.Module):
-    r"""
-    A SQuAD head inspired by XLNet.
-
-    Args:
-        config ([`PretrainedConfig`]):
-            The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
-            to use.
-    """
-
-    def __init__(self, config):
-        super().__init__()
-        self.start_n_top = config.start_n_top
-        self.end_n_top = config.end_n_top
-
-        self.start_logits = PoolerStartLogits(config)
-        self.end_logits = PoolerEndLogits(config)
-        self.answer_class = PoolerAnswerClass(config)
-
-        logger.warning_once(
-            "[DEPRECATION WARNING] `SQuADHead` is deprecated and will be removed in v4.53. "
-            "Please use model-specific class, e.g. `XLMSQuADHead`."
-        )
-
-    @replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig)
-    def forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        start_positions: Optional[torch.LongTensor] = None,
-        end_positions: Optional[torch.LongTensor] = None,
-        cls_index: Optional[torch.LongTensor] = None,
-        is_impossible: Optional[torch.LongTensor] = None,
-        p_mask: Optional[torch.FloatTensor] = None,
-        return_dict: bool = False,
-    ) -> Union[SquadHeadOutput, tuple[torch.FloatTensor]]:
-        """
-        Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
-                Final hidden states of the model on the sequence tokens.
-            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-                Positions of the first token for the labeled span.
-            end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-                Positions of the last token for the labeled span.
-            cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-                Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
-            is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-                Whether the question has a possible answer in the paragraph or not.
-            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
-                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
-                should be masked.
-            return_dict (`bool`, *optional*, defaults to `False`):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-
-        Returns:
-        """
-        start_logits = self.start_logits(hidden_states, p_mask=p_mask)
-
-        if start_positions is not None and end_positions is not None:
-            # If we are on multi-GPU, let's remove the dimension added by batch splitting
-            for x in (start_positions, end_positions, cls_index, is_impossible):
-                if x is not None and x.dim() > 1:
-                    x.squeeze_(-1)
-
-            # during training, compute the end logits based on the ground truth of the start position
-            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
-
-            loss_fct = CrossEntropyLoss()
-            start_loss = loss_fct(start_logits, start_positions)
-            end_loss = loss_fct(end_logits, end_positions)
-            total_loss = (start_loss + end_loss) / 2
-
-            if cls_index is not None and is_impossible is not None:
-                # Predict answerability from the representation of CLS and START
-                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
-                loss_fct_cls = nn.BCEWithLogitsLoss()
-                cls_loss = loss_fct_cls(cls_logits, is_impossible)
-
-                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
-                total_loss += cls_loss * 0.5
-
-            return SquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)
-
-        else:
-            # during inference, compute the end logits based on beam search
-            bsz, slen, hsz = hidden_states.size()
-            start_log_probs = nn.functional.softmax(start_logits, dim=-1)  # shape (bsz, slen)
-
-            start_top_log_probs, start_top_index = torch.topk(
-                start_log_probs, self.start_n_top, dim=-1
-            )  # shape (bsz, start_n_top)
-            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
-            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
-            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)
-
-            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
-                start_states
-            )  # shape (bsz, slen, start_n_top, hsz)
-            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
-            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
-            end_log_probs = nn.functional.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)
-
-            end_top_log_probs, end_top_index = torch.topk(
-                end_log_probs, self.end_n_top, dim=1
-            )  # shape (bsz, end_n_top, start_n_top)
-            end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
-            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
-
-            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
-            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
-
-            if not return_dict:
-                return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
-            else:
-                return SquadHeadOutput(
-                    start_top_log_probs=start_top_log_probs,
-                    start_top_index=start_top_index,
-                    end_top_log_probs=end_top_log_probs,
-                    end_top_index=end_top_index,
-                    cls_logits=cls_logits,
-                )
-
-
-class SequenceSummary(nn.Module):
-    r"""
-    Compute a single vector summary of a sequence hidden states.
-
-    Args:
-        config ([`PretrainedConfig`]):
-            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
-            config class of your model for the default values it uses):
-
-            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
-
-                - `"last"` -- Take the last token hidden state (like XLNet)
-                - `"first"` -- Take the first token hidden state (like Bert)
-                - `"mean"` -- Take the mean of all tokens hidden states
-                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
-                - `"attn"` -- Not implemented now, use multi-head attention
-
-            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
-            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
-              (otherwise to `config.hidden_size`).
-            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
-              another string or `None` will add no activation.
-            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
-            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
-    """
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-
-        self.summary_type = getattr(config, "summary_type", "last")
-        if self.summary_type == "attn":
-            # We should use a standard multi-head attention module with absolute positional embedding for that.
-            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
-            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
-            raise NotImplementedError
-
-        self.summary = Identity()
-        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
-            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
-                num_classes = config.num_labels
-            else:
-                num_classes = config.hidden_size
-            self.summary = nn.Linear(config.hidden_size, num_classes)
-
-        activation_string = getattr(config, "summary_activation", None)
-        self.activation: Callable = get_activation(activation_string) if activation_string else Identity()
-
-        self.first_dropout = Identity()
-        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
-            self.first_dropout = nn.Dropout(config.summary_first_dropout)
-
-        self.last_dropout = Identity()
-        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
-            self.last_dropout = nn.Dropout(config.summary_last_dropout)
-
-        logger.warning_once(
-            "[DEPRECATION WARNING] `SequenceSummary` is deprecated and will be removed in v4.53. "
-            "Please use model-specific class, e.g. `XLMSequenceSummary`."
-        )
-
-    def forward(
-        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
-    ) -> torch.FloatTensor:
-        """
-        Compute a single vector summary of a sequence hidden states.
-
-        Args:
-            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
-                The hidden states of the last layer.
-            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
-                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
-
-        Returns:
-            `torch.FloatTensor`: The summary of the sequence hidden states.
-        """
-        if self.summary_type == "last":
-            output = hidden_states[:, -1]
-        elif self.summary_type == "first":
-            output = hidden_states[:, 0]
-        elif self.summary_type == "mean":
-            output = hidden_states.mean(dim=1)
-        elif self.summary_type == "cls_index":
-            if cls_index is None:
-                cls_index = torch.full_like(
-                    hidden_states[..., :1, :],
-                    hidden_states.shape[-2] - 1,
-                    dtype=torch.long,
-                )
-            else:
-                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
-                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
-            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
-            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
-        elif self.summary_type == "attn":
-            raise NotImplementedError
-
-        output = self.first_dropout(output)
-        output = self.summary(output)
-        output = self.activation(output)
-        output = self.last_dropout(output)
-
-        return output
-
-
 def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
     """
     Recursively unwraps a model from potential containers (as used in distributed training).