diff --git a/docs/source/en/internal/modeling_utils.md b/docs/source/en/internal/modeling_utils.md index 35b8b2e88eb..fe6f961da94 100644 --- a/docs/source/en/internal/modeling_utils.md +++ b/docs/source/en/internal/modeling_utils.md @@ -33,23 +33,6 @@ Most of those are only useful if you are studying the code of the models in the [[autodoc]] pytorch_utils.Conv1D -[[autodoc]] modeling_utils.PoolerStartLogits - - forward - -[[autodoc]] modeling_utils.PoolerEndLogits - - forward - -[[autodoc]] modeling_utils.PoolerAnswerClass - - forward - -[[autodoc]] modeling_utils.SquadHeadOutput - -[[autodoc]] modeling_utils.SQuADHead - - forward - -[[autodoc]] modeling_utils.SequenceSummary - - forward - ## PyTorch Helper Functions [[autodoc]] pytorch_utils.apply_chunking_to_forward diff --git a/docs/source/ja/internal/modeling_utils.md b/docs/source/ja/internal/modeling_utils.md index 62aa2040c8a..6e8335623a0 100644 --- a/docs/source/ja/internal/modeling_utils.md +++ b/docs/source/ja/internal/modeling_utils.md @@ -25,23 +25,6 @@ rendered properly in your Markdown viewer. [[autodoc]] pytorch_utils.Conv1D -[[autodoc]] modeling_utils.PoolerStartLogits - - forward - -[[autodoc]] modeling_utils.PoolerEndLogits - - forward - -[[autodoc]] modeling_utils.PoolerAnswerClass - - forward - -[[autodoc]] modeling_utils.SquadHeadOutput - -[[autodoc]] modeling_utils.SQuADHead - - forward - -[[autodoc]] modeling_utils.SequenceSummary - - forward - ## PyTorch Helper Functions [[autodoc]] pytorch_utils.apply_chunking_to_forward diff --git a/docs/source/ko/internal/modeling_utils.md b/docs/source/ko/internal/modeling_utils.md index 76cc09db292..f84ae30cd6f 100644 --- a/docs/source/ko/internal/modeling_utils.md +++ b/docs/source/ko/internal/modeling_utils.md @@ -25,23 +25,6 @@ rendered properly in your Markdown viewer. [[autodoc]] pytorch_utils.Conv1D -[[autodoc]] modeling_utils.PoolerStartLogits - - forward - -[[autodoc]] modeling_utils.PoolerEndLogits - - forward - -[[autodoc]] modeling_utils.PoolerAnswerClass - - forward - -[[autodoc]] modeling_utils.SquadHeadOutput - -[[autodoc]] modeling_utils.SQuADHead - - forward - -[[autodoc]] modeling_utils.SequenceSummary - - forward - ## PyTorch 헬퍼(helper) 함수 [[transformers.apply_chunking_to_forward]] [[autodoc]] pytorch_utils.apply_chunking_to_forward diff --git a/docs/source/zh/internal/modeling_utils.md b/docs/source/zh/internal/modeling_utils.md index 93341b323e8..2cc62711c71 100644 --- a/docs/source/zh/internal/modeling_utils.md +++ b/docs/source/zh/internal/modeling_utils.md @@ -25,23 +25,6 @@ rendered properly in your Markdown viewer. [[autodoc]] pytorch_utils.Conv1D -[[autodoc]] modeling_utils.PoolerStartLogits - - forward - -[[autodoc]] modeling_utils.PoolerEndLogits - - forward - -[[autodoc]] modeling_utils.PoolerAnswerClass - - forward - -[[autodoc]] modeling_utils.SquadHeadOutput - -[[autodoc]] modeling_utils.SQuADHead - - forward - -[[autodoc]] modeling_utils.SequenceSummary - - forward - ## PyTorch帮助函数 [[autodoc]] pytorch_utils.apply_chunking_to_forward diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3be400b93d9..50a200ae76b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -5384,6 +5384,10 @@ class PoolerStartLogits(nn.Module): def __init__(self, config: PretrainedConfig): super().__init__() self.dense = nn.Linear(config.hidden_size, 1) + logger.warning_once( + "[DEPRECATION WARNING] `PoolerStartLogits` is deprecated and will be removed in v4.53. 
" + "Please use model-specific class, e.g. `XLMPoolerStartLogits`." + ) def forward( self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None @@ -5426,6 +5430,10 @@ class PoolerEndLogits(nn.Module): self.activation = nn.Tanh() self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dense_1 = nn.Linear(config.hidden_size, 1) + logger.warning_once( + "[DEPRECATION WARNING] `PoolerEndLogits` is deprecated and will be removed in v4.53. " + "Please use model-specific class, e.g. `XLMPoolerEndLogits`." + ) def forward( self, @@ -5493,6 +5501,10 @@ class PoolerAnswerClass(nn.Module): self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) self.activation = nn.Tanh() self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) + logger.warning_once( + "[DEPRECATION WARNING] `PoolerAnswerClass` is deprecated and will be removed in v4.53. " + "Please use model-specific class, e.g. `XLMPoolerAnswerClass`." + ) def forward( self, @@ -5574,6 +5586,12 @@ class SquadHeadOutput(ModelOutput): end_top_index: Optional[torch.LongTensor] = None cls_logits: Optional[torch.FloatTensor] = None + def __post_init__(self): + logger.warning_once( + "[DEPRECATION WARNING] `SquadHeadOutput` is deprecated and will be removed in v4.53. " + "Please use model-specific class, e.g. `XLMSquadHeadOutput`." + ) + class SQuADHead(nn.Module): r""" @@ -5594,6 +5612,11 @@ class SQuADHead(nn.Module): self.end_logits = PoolerEndLogits(config) self.answer_class = PoolerAnswerClass(config) + logger.warning_once( + "[DEPRECATION WARNING] `SQuADHead` is deprecated and will be removed in v4.53. " + "Please use model-specific class, e.g. `XLMSQuADHead`." + ) + @replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig) def forward( self, @@ -5747,6 +5770,11 @@ class SequenceSummary(nn.Module): if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: self.last_dropout = nn.Dropout(config.summary_last_dropout) + logger.warning_once( + "[DEPRECATION WARNING] `SequenceSummary` is deprecated and will be removed in v4.53. " + "Please use model-specific class, e.g. `XLMSequenceSummary`." 
+ ) + def forward( self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None ) -> torch.FloatTensor: diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index 320cf17126c..bafaa36dd67 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -18,14 +18,14 @@ import copy import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss -from ...activations import ACT2FN +from ...activations import ACT2FN, get_activation from ...generation import GenerationConfig, GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( @@ -34,7 +34,7 @@ from ...modeling_outputs import ( BaseModelOutputWithPooling, CausalLMOutputWithCrossAttentions, ) -from ...modeling_utils import PreTrainedModel, SequenceSummary +from ...modeling_utils import PreTrainedModel from ...pytorch_utils import Conv1D, isin_mps_friendly from ...utils import ( ModelOutput, @@ -499,6 +499,106 @@ class ClvpEncoderLayer(nn.Module): return outputs +# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Clvp +class ClvpSequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`ClvpConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: ClvpConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. 
https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = nn.Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity() + + self.first_dropout = nn.Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = nn.Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. + """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->ClvpDecoderMLP class ClvpDecoderMLP(nn.Module): def __init__(self, intermediate_size, config): @@ -884,7 +984,7 @@ class ClvpEncoder(ClvpPreTrainedModel): self.rotary_pos_emb = ClvpRotaryPositionalEmbedding(config) if config.use_rotary_embedding else None self.layers = nn.ModuleList([ClvpEncoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.sequence_summary = SequenceSummary(config) + self.sequence_summary = ClvpSequenceSummary(config) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False) diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index 13efb4839a8..486e678b596 
100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -17,7 +17,7 @@ import math import os from operator import attrgetter -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -33,7 +33,7 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, SequenceSummary +from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_convbert import ConvBertConfig @@ -683,6 +683,106 @@ class ConvBertPredictionHeadTransform(nn.Module): return hidden_states +# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->ConvBert +class ConvBertSequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`ConvBertConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: ConvBertConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. 
https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = nn.Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity() + + self.first_dropout = nn.Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = nn.Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. + """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + CONVBERT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. 
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and @@ -1077,7 +1177,7 @@ class ConvBertForMultipleChoice(ConvBertPreTrainedModel): super().__init__(config) self.convbert = ConvBertModel(config) - self.sequence_summary = SequenceSummary(config) + self.sequence_summary = ConvBertSequenceSummary(config) self.classifier = nn.Linear(config.hidden_size, 1) # Initialize weights and apply final processing diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 7b73f022122..08cc3e530d6 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -17,7 +17,7 @@ import math import os from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -36,7 +36,7 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, SequenceSummary +from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, @@ -946,6 +946,106 @@ class ElectraClassificationHead(nn.Module): return x +# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Electra +class ElectraSequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`ElectraConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: ElectraConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. 
https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = nn.Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity() + + self.first_dropout = nn.Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = nn.Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. + """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + @add_start_docstrings( """ ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of the @@ -1442,7 +1542,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel): super().__init__(config) self.electra = ElectraModel(config) - self.sequence_summary = SequenceSummary(config) + self.sequence_summary = ElectraSequenceSummary(config) self.classifier = nn.Linear(config.hidden_size, 1) # Initialize weights and apply final processing diff --git a/src/transformers/models/flaubert/modeling_flaubert.py b/src/transformers/models/flaubert/modeling_flaubert.py index bc1d66f8355..b5a5ea793ae 100644 --- a/src/transformers/models/flaubert/modeling_flaubert.py +++ b/src/transformers/models/flaubert/modeling_flaubert.py @@ -17,14 +17,14 @@ import itertools import math from dataclasses import dataclass -from 
typing import Dict, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Tuple, Union import numpy as np import torch from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from ...activations import gelu +from ...activations import gelu, get_activation from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, @@ -34,7 +34,7 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead +from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, @@ -329,6 +329,431 @@ class FlaubertPredLayer(nn.Module): return outputs +@dataclass +# Copied from transformers.models.xlm.modeling_xlm.XLMSquadHeadOutput with XLM->Flaubert +class FlaubertSquadHeadOutput(ModelOutput): + """ + Base class for outputs of question answering models using a [`~modeling_utils.FlaubertSQuADHead`]. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided): + Classification loss as the sum of start token, end token (and is_impossible if provided) classification + losses. + start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the top config.start_n_top start token possibilities (beam-search). + start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Indices for the top config.start_n_top start token possibilities (beam-search). + end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities + (beam-search). + end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search). + cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the `is_impossible` label of the answers. + + """ + + loss: Optional[torch.FloatTensor] = None + start_top_log_probs: Optional[torch.FloatTensor] = None + start_top_index: Optional[torch.LongTensor] = None + end_top_log_probs: Optional[torch.FloatTensor] = None + end_top_index: Optional[torch.LongTensor] = None + cls_logits: Optional[torch.FloatTensor] = None + + +# Copied from transformers.models.xlm.modeling_xlm.XLMPoolerStartLogits with XLM->Flaubert +class FlaubertPoolerStartLogits(nn.Module): + """ + Compute SQuAD start logits from sequence hidden states. + + Args: + config ([`FlaubertConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model. 
+ """ + + def __init__(self, config: FlaubertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, 1) + + def forward( + self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + + Returns: + `torch.FloatTensor`: The start logits for SQuAD. + """ + x = self.dense(hidden_states).squeeze(-1) + + if p_mask is not None: + if p_mask.dtype == torch.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +# Copied from transformers.models.xlm.modeling_xlm.XLMPoolerEndLogits with XLM->Flaubert +class FlaubertPoolerEndLogits(nn.Module): + """ + Compute SQuAD end logits from sequence hidden states. + + Args: + config ([`FlaubertConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps` + to use. + """ + + def __init__(self, config: FlaubertConfig): + super().__init__() + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = nn.Tanh() + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense_1 = nn.Linear(config.hidden_size, 1) + + def forward( + self, + hidden_states: torch.FloatTensor, + start_states: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + p_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + The hidden states of the first tokens for the labeled span. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + The position of the first token for the labeled span. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + + + + One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides + `start_states`. + + + + Returns: + `torch.FloatTensor`: The end logits for SQuAD. 
+ """ + assert start_states is not None or start_positions is not None, ( + "One of start_states, start_positions should be not None" + ) + if start_positions is not None: + slen, hsz = hidden_states.shape[-2:] + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) + start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) + + x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) + x = self.activation(x) + x = self.LayerNorm(x) + x = self.dense_1(x).squeeze(-1) + + if p_mask is not None: + if p_mask.dtype == torch.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +# Copied from transformers.models.xlm.modeling_xlm.XLMPoolerAnswerClass with XLM->Flaubert +class FlaubertPoolerAnswerClass(nn.Module): + """ + Compute SQuAD 2.0 answer class from classification and start tokens hidden states. + + Args: + config ([`FlaubertConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model. + """ + + def __init__(self, config: FlaubertConfig): + super().__init__() + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = nn.Tanh() + self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) + + def forward( + self, + hidden_states: torch.FloatTensor, + start_states: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + cls_index: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + The hidden states of the first tokens for the labeled span. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + The position of the first token for the labeled span. + cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Position of the CLS token for each sentence in the batch. If `None`, takes the last token. + + + + One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides + `start_states`. + + + + Returns: + `torch.FloatTensor`: The SQuAD 2.0 answer class. + """ + # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample. + hsz = hidden_states.shape[-1] + assert start_states is not None or start_positions is not None, ( + "One of start_states, start_positions should be not None" + ) + if start_positions is not None: + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) + + if cls_index is not None: + cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) + else: + cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) + + x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)) + x = self.activation(x) + x = self.dense_1(x).squeeze(-1) + + return x + + +# Copied from transformers.models.xlm.modeling_xlm.XLMSQuADHead with XLM->Flaubert +class FlaubertSQuADHead(nn.Module): + r""" + A SQuAD head inspired by XLNet. 
+ + Args: + config ([`FlaubertConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps` + to use. + """ + + def __init__(self, config: FlaubertConfig): + super().__init__() + self.start_n_top = config.start_n_top + self.end_n_top = config.end_n_top + + self.start_logits = FlaubertPoolerStartLogits(config) + self.end_logits = FlaubertPoolerEndLogits(config) + self.answer_class = FlaubertPoolerAnswerClass(config) + + @replace_return_docstrings(output_type=FlaubertSquadHeadOutput, config_class=FlaubertConfig) + def forward( + self, + hidden_states: torch.FloatTensor, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + cls_index: Optional[torch.LongTensor] = None, + is_impossible: Optional[torch.LongTensor] = None, + p_mask: Optional[torch.FloatTensor] = None, + return_dict: bool = False, + ) -> Union[FlaubertSquadHeadOutput, Tuple[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + Final hidden states of the model on the sequence tokens. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Positions of the first token for the labeled span. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Positions of the last token for the labeled span. + cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Position of the CLS token for each sentence in the batch. If `None`, takes the last token. + is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Whether the question has a possible answer in the paragraph or not. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ + Returns: + """ + start_logits = self.start_logits(hidden_states, p_mask=p_mask) + + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, let's remove the dimension added by batch splitting + for x in (start_positions, end_positions, cls_index, is_impossible): + if x is not None and x.dim() > 1: + x.squeeze_(-1) + + # during training, compute the end logits based on the ground truth of the start position + end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) + + loss_fct = CrossEntropyLoss() + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if cls_index is not None and is_impossible is not None: + # Predict answerability from the representation of CLS and START + cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) + loss_fct_cls = nn.BCEWithLogitsLoss() + cls_loss = loss_fct_cls(cls_logits, is_impossible) + + # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss + total_loss += cls_loss * 0.5 + + return FlaubertSquadHeadOutput(loss=total_loss) if return_dict else (total_loss,) + + else: + # during inference, compute the end logits based on beam search + bsz, slen, hsz = hidden_states.size() + start_log_probs = nn.functional.softmax(start_logits, dim=-1) # shape (bsz, slen) + + start_top_log_probs, start_top_index = torch.topk( + start_log_probs, self.start_n_top, dim=-1 + ) # shape (bsz, start_n_top) + start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) + start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) + start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) + + hidden_states_expanded = hidden_states.unsqueeze(2).expand_as( + start_states + ) # shape (bsz, slen, start_n_top, hsz) + p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None + end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) + end_log_probs = nn.functional.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) + + end_top_log_probs, end_top_index = torch.topk( + end_log_probs, self.end_n_top, dim=1 + ) # shape (bsz, end_n_top, start_n_top) + end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) + end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) + + start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) + cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) + + if not return_dict: + return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + else: + return FlaubertSquadHeadOutput( + start_top_log_probs=start_top_log_probs, + start_top_index=start_top_index, + end_top_log_probs=end_top_log_probs, + end_top_index=end_top_index, + cls_logits=cls_logits, + ) + + +# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Flaubert +class FlaubertSequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`FlaubertConfig`]): + The config used by the model. 
Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: FlaubertConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = nn.Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity() + + self.first_dropout = nn.Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = nn.Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. 
+ """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + # Copied from transformers.models.xlm.modeling_xlm.XLMPreTrainedModel with XLM->Flaubert class FlaubertPreTrainedModel(PreTrainedModel): """ @@ -744,7 +1169,7 @@ class FlaubertWithLMHeadModel(FlaubertPreTrainedModel, GenerationMixin): """, FLAUBERT_START_DOCSTRING, ) -# Copied transformers.models.xlm.modeling_xlm.XLMForSequenceClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert +# Copied from transformers.models.xlm.modeling_xlm.XLMForSequenceClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert class FlaubertForSequenceClassification(FlaubertPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -752,7 +1177,7 @@ class FlaubertForSequenceClassification(FlaubertPreTrainedModel): self.config = config self.transformer = FlaubertModel(config) - self.sequence_summary = SequenceSummary(config) + self.sequence_summary = FlaubertSequenceSummary(config) # Initialize weights and apply final processing self.post_init() @@ -1081,13 +1506,13 @@ class FlaubertForQuestionAnsweringOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None -# Copied from transformer.models.xlm.modeling_xlm.XLMForQuestionAnswering with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert +# Copied from transformers.models.xlm.modeling_xlm.XLMForQuestionAnswering with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert class FlaubertForQuestionAnswering(FlaubertPreTrainedModel): def __init__(self, config): super().__init__(config) self.transformer = FlaubertModel(config) - self.qa_outputs = SQuADHead(config) + self.qa_outputs = FlaubertSQuADHead(config) # Initialize weights and apply final processing self.post_init() @@ -1137,11 +1562,11 @@ class FlaubertForQuestionAnswering(FlaubertPreTrainedModel): Example: ```python - >>> from transformers import XLMTokenizer, XLMForQuestionAnswering + >>> from transformers import AutoTokenizer, FlaubertForQuestionAnswering >>> import torch - >>> tokenizer = XLMTokenizer.from_pretrained("FacebookAI/xlm-mlm-en-2048") - >>> model = XLMForQuestionAnswering.from_pretrained("FacebookAI/xlm-mlm-en-2048") + >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-mlm-en-2048") + >>> model = FlaubertForQuestionAnswering.from_pretrained("FacebookAI/xlm-mlm-en-2048") >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze( ... 
0 @@ -1203,13 +1628,13 @@ class FlaubertForQuestionAnswering(FlaubertPreTrainedModel): """, FLAUBERT_START_DOCSTRING, ) -# Copied from transformer.models.xlm.modeling_xlm.XLMForMultipleChoice with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert +# Copied from transformers.models.xlm.modeling_xlm.XLMForMultipleChoice with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert class FlaubertForMultipleChoice(FlaubertPreTrainedModel): def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.transformer = FlaubertModel(config) - self.sequence_summary = SequenceSummary(config) + self.sequence_summary = FlaubertSequenceSummary(config) self.logits_proj = nn.Linear(config.num_labels, 1) # Initialize weights and apply final processing diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 75c148f233a..717962b3377 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -26,7 +26,7 @@ import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from ...activations import ACT2FN +from ...activations import ACT2FN, get_activation from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( @@ -36,7 +36,7 @@ from ...modeling_outputs import ( SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, SequenceSummary +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer from ...utils import ( ModelOutput, @@ -450,6 +450,106 @@ class GPT2Block(nn.Module): return outputs # hidden_states, present, (attentions, cross_attentions) +# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->GPT2 +class GPT2SequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`GPT2Config`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. 
+ """ + + def __init__(self, config: GPT2Config): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = nn.Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity() + + self.first_dropout = nn.Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = nn.Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. 
+ """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + class GPT2PreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -1138,7 +1238,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin): config.num_labels = 1 self.transformer = GPT2Model(config) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - self.multiple_choice_head = SequenceSummary(config) + self.multiple_choice_head = GPT2SequenceSummary(config) # Model parallel self.model_parallel = False diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index 595d3e93735..244ad6d50a2 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -19,16 +19,16 @@ import json import math import os from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import torch from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from ...activations import gelu_new, silu +from ...activations import gelu_new, get_activation, silu from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput -from ...modeling_utils import PreTrainedModel, SequenceSummary +from ...modeling_utils import PreTrainedModel from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer from ...utils import ( ModelOutput, @@ -262,6 +262,106 @@ class Block(nn.Module): return outputs +# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->OpenAIGPT +class OpenAIGPTSequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`OpenAIGPTConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. 
+ - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: OpenAIGPTConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = nn.Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity() + + self.first_dropout = nn.Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = nn.Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. 
+ """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + class OpenAIGPTPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -628,7 +728,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): config.num_labels = 1 self.transformer = OpenAIGPTModel(config) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - self.multiple_choice_head = SequenceSummary(config) + self.multiple_choice_head = OpenAIGPTSequenceSummary(config) # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 7814472bb48..86465153b10 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -16,7 +16,7 @@ import math import os -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -24,7 +24,7 @@ import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from ...activations import ACT2FN +from ...activations import ACT2FN, get_activation from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -35,7 +35,7 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, SequenceSummary +from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( add_code_sample_docstrings, @@ -620,6 +620,106 @@ class RoFormerEncoder(nn.Module): ) +# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->RoFormer +class RoFormerSequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`RoFormerConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. 
Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: RoFormerConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = nn.Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity() + + self.first_dropout = nn.Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = nn.Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. 
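# Illustrative sketch, not part of this patch: the summary module reads plain attributes
# off the config (hidden_size, num_labels and the summary_* fields listed above), so a
# minimal stand-in namespace is enough to see which sub-layers get built. Assumes a
# transformers build that already contains this change; `cfg` is a made-up object used
# here instead of a real RoFormerConfig, for illustration only.
from types import SimpleNamespace

import torch
from transformers.models.roformer.modeling_roformer import RoFormerSequenceSummary

cfg = SimpleNamespace(
    hidden_size=16,
    num_labels=3,
    summary_type="first",
    summary_use_proj=True,
    summary_proj_to_labels=True,   # project to num_labels rather than hidden_size
    summary_activation=None,       # falls back to nn.Identity()
    summary_first_dropout=0.1,
    summary_last_dropout=0.0,
)
summary = RoFormerSequenceSummary(cfg)
out = summary(torch.randn(2, 7, cfg.hidden_size))   # takes the first token, then projects
assert out.shape == (2, cfg.num_labels)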
+ """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + class RoFormerPredictionHeadTransform(nn.Module): def __init__(self, config): super().__init__() @@ -1292,7 +1392,7 @@ class RoFormerForMultipleChoice(RoFormerPreTrainedModel): super().__init__(config) self.roformer = RoFormerModel(config) - self.sequence_summary = SequenceSummary(config) + self.sequence_summary = RoFormerSequenceSummary(config) self.classifier = nn.Linear(config.hidden_size, 1) # Initialize weights and apply final processing diff --git a/src/transformers/models/univnet/modeling_univnet.py b/src/transformers/models/univnet/modeling_univnet.py index 3c73625592f..d105711b5dc 100644 --- a/src/transformers/models/univnet/modeling_univnet.py +++ b/src/transformers/models/univnet/modeling_univnet.py @@ -20,7 +20,8 @@ import torch import torch.utils.checkpoint from torch import nn -from ...modeling_utils import ModelOutput, PreTrainedModel +from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_univnet import UnivNetConfig diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py index 0ffa3319081..16f1d4ec3ff 100755 --- a/src/transformers/models/xlm/modeling_xlm.py +++ b/src/transformers/models/xlm/modeling_xlm.py @@ -19,14 +19,14 @@ PyTorch XLM model. import itertools import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Tuple, Union import numpy as np import torch from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from ...activations import gelu +from ...activations import gelu, get_activation from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, @@ -36,7 +36,7 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead +from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, @@ -88,6 +88,425 @@ def get_masks(slen, lengths, causal, padding_mask=None): return mask, attn_mask +@dataclass +class XLMSquadHeadOutput(ModelOutput): + """ + Base class for outputs of question answering models using a [`~modeling_utils.XLMSQuADHead`]. 
+ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided): + Classification loss as the sum of start token, end token (and is_impossible if provided) classification + losses. + start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the top config.start_n_top start token possibilities (beam-search). + start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Indices for the top config.start_n_top start token possibilities (beam-search). + end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities + (beam-search). + end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search). + cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the `is_impossible` label of the answers. + + """ + + loss: Optional[torch.FloatTensor] = None + start_top_log_probs: Optional[torch.FloatTensor] = None + start_top_index: Optional[torch.LongTensor] = None + end_top_log_probs: Optional[torch.FloatTensor] = None + end_top_index: Optional[torch.LongTensor] = None + cls_logits: Optional[torch.FloatTensor] = None + + +class XLMPoolerStartLogits(nn.Module): + """ + Compute SQuAD start logits from sequence hidden states. + + Args: + config ([`XLMConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model. + """ + + def __init__(self, config: XLMConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, 1) + + def forward( + self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + + Returns: + `torch.FloatTensor`: The start logits for SQuAD. + """ + x = self.dense(hidden_states).squeeze(-1) + + if p_mask is not None: + if p_mask.dtype == torch.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +class XLMPoolerEndLogits(nn.Module): + """ + Compute SQuAD end logits from sequence hidden states. + + Args: + config ([`XLMConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps` + to use. 
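# Illustrative sketch, not part of this patch: the p_mask convention used by the pooler
# heads above. A value of 1.0 marks positions that must not be selected (query tokens,
# PAD, SEP, CLS); the large negative offset pushes their softmax probability to ~0.
# (The float16 branch uses -65500 instead because 1e30 overflows half precision.)
import torch

logits = torch.tensor([[2.0, 0.5, -1.0, 3.0]])
p_mask = torch.tensor([[0.0, 1.0, 1.0, 0.0]])
masked = logits * (1 - p_mask) - 1e30 * p_mask
probs = masked.softmax(dim=-1)
assert probs[0, 1].item() < 1e-12 and probs[0, 2].item() < 1e-12  # masked positions never win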
+ """ + + def __init__(self, config: XLMConfig): + super().__init__() + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = nn.Tanh() + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense_1 = nn.Linear(config.hidden_size, 1) + + def forward( + self, + hidden_states: torch.FloatTensor, + start_states: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + p_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + The hidden states of the first tokens for the labeled span. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + The position of the first token for the labeled span. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + + + + One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides + `start_states`. + + + + Returns: + `torch.FloatTensor`: The end logits for SQuAD. + """ + assert start_states is not None or start_positions is not None, ( + "One of start_states, start_positions should be not None" + ) + if start_positions is not None: + slen, hsz = hidden_states.shape[-2:] + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) + start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) + + x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) + x = self.activation(x) + x = self.LayerNorm(x) + x = self.dense_1(x).squeeze(-1) + + if p_mask is not None: + if p_mask.dtype == torch.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +class XLMPoolerAnswerClass(nn.Module): + """ + Compute SQuAD 2.0 answer class from classification and start tokens hidden states. + + Args: + config ([`XLMConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model. + """ + + def __init__(self, config: XLMConfig): + super().__init__() + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = nn.Tanh() + self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) + + def forward( + self, + hidden_states: torch.FloatTensor, + start_states: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + cls_index: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + The hidden states of the first tokens for the labeled span. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + The position of the first token for the labeled span. + cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Position of the CLS token for each sentence in the batch. If `None`, takes the last token. + + + + One of `start_states` or `start_positions` should be not `None`. 
If both are set, `start_positions` overrides + `start_states`. + + + + Returns: + `torch.FloatTensor`: The SQuAD 2.0 answer class. + """ + # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample. + hsz = hidden_states.shape[-1] + assert start_states is not None or start_positions is not None, ( + "One of start_states, start_positions should be not None" + ) + if start_positions is not None: + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) + + if cls_index is not None: + cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) + else: + cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) + + x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)) + x = self.activation(x) + x = self.dense_1(x).squeeze(-1) + + return x + + +class XLMSQuADHead(nn.Module): + r""" + A SQuAD head inspired by XLNet. + + Args: + config ([`XLMConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps` + to use. + """ + + def __init__(self, config: XLMConfig): + super().__init__() + self.start_n_top = config.start_n_top + self.end_n_top = config.end_n_top + + self.start_logits = XLMPoolerStartLogits(config) + self.end_logits = XLMPoolerEndLogits(config) + self.answer_class = XLMPoolerAnswerClass(config) + + @replace_return_docstrings(output_type=XLMSquadHeadOutput, config_class=XLMConfig) + def forward( + self, + hidden_states: torch.FloatTensor, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + cls_index: Optional[torch.LongTensor] = None, + is_impossible: Optional[torch.LongTensor] = None, + p_mask: Optional[torch.FloatTensor] = None, + return_dict: bool = False, + ) -> Union[XLMSquadHeadOutput, Tuple[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + Final hidden states of the model on the sequence tokens. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Positions of the first token for the labeled span. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Positions of the last token for the labeled span. + cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Position of the CLS token for each sentence in the batch. If `None`, takes the last token. + is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Whether the question has a possible answer in the paragraph or not. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
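# Illustrative sketch, not part of this patch: the einsum in the inference branch below,
# torch.einsum("blh,bl->bh", hidden_states, start_log_probs), is just a probability-
# weighted average of the hidden states over the sequence dimension.
import torch

bsz, slen, hsz = 2, 4, 8
hidden_states = torch.randn(bsz, slen, hsz)
start_log_probs = torch.randn(bsz, slen).softmax(dim=-1)   # mirrors the source's naming
weighted = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
manual = (hidden_states * start_log_probs.unsqueeze(-1)).sum(dim=1)
assert torch.allclose(weighted, manual, atol=1e-6)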
+ + Returns: + """ + start_logits = self.start_logits(hidden_states, p_mask=p_mask) + + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, let's remove the dimension added by batch splitting + for x in (start_positions, end_positions, cls_index, is_impossible): + if x is not None and x.dim() > 1: + x.squeeze_(-1) + + # during training, compute the end logits based on the ground truth of the start position + end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) + + loss_fct = CrossEntropyLoss() + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if cls_index is not None and is_impossible is not None: + # Predict answerability from the representation of CLS and START + cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) + loss_fct_cls = nn.BCEWithLogitsLoss() + cls_loss = loss_fct_cls(cls_logits, is_impossible) + + # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss + total_loss += cls_loss * 0.5 + + return XLMSquadHeadOutput(loss=total_loss) if return_dict else (total_loss,) + + else: + # during inference, compute the end logits based on beam search + bsz, slen, hsz = hidden_states.size() + start_log_probs = nn.functional.softmax(start_logits, dim=-1) # shape (bsz, slen) + + start_top_log_probs, start_top_index = torch.topk( + start_log_probs, self.start_n_top, dim=-1 + ) # shape (bsz, start_n_top) + start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) + start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) + start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) + + hidden_states_expanded = hidden_states.unsqueeze(2).expand_as( + start_states + ) # shape (bsz, slen, start_n_top, hsz) + p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None + end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) + end_log_probs = nn.functional.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) + + end_top_log_probs, end_top_index = torch.topk( + end_log_probs, self.end_n_top, dim=1 + ) # shape (bsz, end_n_top, start_n_top) + end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) + end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) + + start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) + cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) + + if not return_dict: + return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + else: + return XLMSquadHeadOutput( + start_top_log_probs=start_top_log_probs, + start_top_index=start_top_index, + end_top_log_probs=end_top_log_probs, + end_top_index=end_top_index, + cls_logits=cls_logits, + ) + + +class XLMSequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`XLMConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. 
Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: XLMConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = nn.Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity() + + self.first_dropout = nn.Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = nn.Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. 
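# Illustrative sketch, not part of this patch: when cls_index is None, the "cls_index"
# branch implemented just below defaults to the last token, i.e. it reduces to
# hidden_states[:, -1].
import torch

hidden_states = torch.randn(2, 5, 8)
cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long)
output = hidden_states.gather(-2, cls_index).squeeze(-2)
assert torch.equal(output, hidden_states[:, -1])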
+ """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + class MultiHeadAttention(nn.Module): NEW_ID = itertools.count() @@ -252,7 +671,7 @@ class XLMPreTrainedModel(PreTrainedModel): @dataclass class XLMForQuestionAnsweringOutput(ModelOutput): """ - Base class for outputs of question answering models using a `SquadHead`. + Base class for outputs of question answering models using a `XLMSQuADHead`. Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided): @@ -767,7 +1186,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel): self.config = config self.transformer = XLMModel(config) - self.sequence_summary = SequenceSummary(config) + self.sequence_summary = XLMSequenceSummary(config) # Initialize weights and apply final processing self.post_init() @@ -971,7 +1390,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel): super().__init__(config) self.transformer = XLMModel(config) - self.qa_outputs = SQuADHead(config) + self.qa_outputs = XLMSQuADHead(config) # Initialize weights and apply final processing self.post_init() @@ -1176,7 +1595,7 @@ class XLMForMultipleChoice(XLMPreTrainedModel): super().__init__(config, *inputs, **kwargs) self.transformer = XLMModel(config) - self.sequence_summary = SequenceSummary(config) + self.sequence_summary = XLMSequenceSummary(config) self.logits_proj = nn.Linear(config.num_labels, 1) # Initialize weights and apply final processing diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index db0703f949c..de7446e57bb 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -19,15 +19,15 @@ PyTorch XLNet model. 
import warnings from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from ...activations import ACT2FN +from ...activations import ACT2FN, get_activation from ...generation import GenerationMixin -from ...modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary +from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward from ...utils import ( ModelOutput, @@ -529,6 +529,281 @@ class XLNetLayer(nn.Module): return output_x +# Copied from transformers.models.xlm.modeling_xlm.XLMPoolerStartLogits with XLM->XLNet +class XLNetPoolerStartLogits(nn.Module): + """ + Compute SQuAD start logits from sequence hidden states. + + Args: + config ([`XLNetConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model. + """ + + def __init__(self, config: XLNetConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, 1) + + def forward( + self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + + Returns: + `torch.FloatTensor`: The start logits for SQuAD. + """ + x = self.dense(hidden_states).squeeze(-1) + + if p_mask is not None: + if p_mask.dtype == torch.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +# Copied from transformers.models.xlm.modeling_xlm.XLMPoolerEndLogits with XLM->XLNet +class XLNetPoolerEndLogits(nn.Module): + """ + Compute SQuAD end logits from sequence hidden states. + + Args: + config ([`XLNetConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps` + to use. + """ + + def __init__(self, config: XLNetConfig): + super().__init__() + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = nn.Tanh() + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense_1 = nn.Linear(config.hidden_size, 1) + + def forward( + self, + hidden_states: torch.FloatTensor, + start_states: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + p_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + The hidden states of the first tokens for the labeled span. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + The position of the first token for the labeled span. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + + + + One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides + `start_states`. 
+ + + + Returns: + `torch.FloatTensor`: The end logits for SQuAD. + """ + assert start_states is not None or start_positions is not None, ( + "One of start_states, start_positions should be not None" + ) + if start_positions is not None: + slen, hsz = hidden_states.shape[-2:] + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) + start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) + + x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) + x = self.activation(x) + x = self.LayerNorm(x) + x = self.dense_1(x).squeeze(-1) + + if p_mask is not None: + if p_mask.dtype == torch.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +# Copied from transformers.models.xlm.modeling_xlm.XLMPoolerAnswerClass with XLM->XLNet +class XLNetPoolerAnswerClass(nn.Module): + """ + Compute SQuAD 2.0 answer class from classification and start tokens hidden states. + + Args: + config ([`XLNetConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model. + """ + + def __init__(self, config: XLNetConfig): + super().__init__() + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = nn.Tanh() + self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) + + def forward( + self, + hidden_states: torch.FloatTensor, + start_states: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + cls_index: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + The hidden states of the first tokens for the labeled span. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + The position of the first token for the labeled span. + cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Position of the CLS token for each sentence in the batch. If `None`, takes the last token. + + + + One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides + `start_states`. + + + + Returns: + `torch.FloatTensor`: The SQuAD 2.0 answer class. + """ + # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample. 
+ hsz = hidden_states.shape[-1] + assert start_states is not None or start_positions is not None, ( + "One of start_states, start_positions should be not None" + ) + if start_positions is not None: + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) + + if cls_index is not None: + cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) + else: + cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) + + x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)) + x = self.activation(x) + x = self.dense_1(x).squeeze(-1) + + return x + + +# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->XLNet +class XLNetSequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`XLNetConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: XLNetConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. 
https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = nn.Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity() + + self.first_dropout = nn.Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = nn.Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. + """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + class XLNetPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -1502,7 +1777,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): self.config = config self.transformer = XLNetModel(config) - self.sequence_summary = SequenceSummary(config) + self.sequence_summary = XLNetSequenceSummary(config) self.logits_proj = nn.Linear(config.d_model, config.num_labels) # Initialize weights and apply final processing @@ -1696,7 +1971,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel): super().__init__(config) self.transformer = XLNetModel(config) - self.sequence_summary = SequenceSummary(config) + self.sequence_summary = XLNetSequenceSummary(config) self.logits_proj = nn.Linear(config.d_model, 1) # Initialize weights and apply final 
processing @@ -1911,9 +2186,9 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): self.end_n_top = config.end_n_top self.transformer = XLNetModel(config) - self.start_logits = PoolerStartLogits(config) - self.end_logits = PoolerEndLogits(config) - self.answer_class = PoolerAnswerClass(config) + self.start_logits = XLNetPoolerStartLogits(config) + self.end_logits = XLNetPoolerEndLogits(config) + self.answer_class = XLNetPoolerAnswerClass(config) # Initialize weights and apply final processing self.post_init() diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index ca677f6549b..9865067bcdb 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -327,17 +327,6 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s is not None ): attribute_used = True - # `SequenceSummary` is called with `SequenceSummary(config)` - elif attribute in [ - "summary_type", - "summary_use_proj", - "summary_activation", - "summary_last_dropout", - "summary_proj_to_labels", - "summary_first_dropout", - ]: - if "SequenceSummary" in modeling_source: - attribute_used = True if attribute_used: break if attribute_used:
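# Illustrative migration sketch, not part of this patch: with these changes applied, the
# shared heads live next to each model instead of in modeling_utils, so downstream code
# should switch to the model-specific classes.
# Before (deprecated in favor of the model-specific classes):
#   from transformers.modeling_utils import SequenceSummary, SQuADHead
# After (assumes a transformers build that contains this patch):
from transformers.models.xlm.modeling_xlm import XLMSequenceSummary, XLMSQuADHead
from transformers.models.xlnet.modeling_xlnet import XLNetSequenceSummary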