Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 13:20:12 +06:00)

Deprecate modeling_utils.py classes (#37298)

* Move utils classes into models
* Add deprecation warnings
* Remove from docs
* Update config attributes check

This commit is contained in:
parent a245011252
commit 4f58fc9c82
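The change pattern, in short: the shared head/utility classes in `modeling_utils.py` (`PoolerStartLogits`, `PoolerEndLogits`, `PoolerAnswerClass`, `SquadHeadOutput`, `SQuADHead`, `SequenceSummary`) now emit a one-time deprecation warning, and each model that used them gets its own copy (e.g. `ClvpSequenceSummary`, `FlaubertSQuADHead`). A minimal, hedged sketch of what this means for downstream imports; the module paths and the removal version (v4.53) are taken from the diff below, not from a released API reference:

```python
# Sketch only; paths come from the diff in this commit.

# Still importable until removal, but now logs a one-time "[DEPRECATION WARNING] ..." at init:
from transformers.modeling_utils import SequenceSummary  # deprecated

# Preferred going forward: the per-model copy, e.g. the XLM one referenced by the warning text:
from transformers.models.xlm.modeling_xlm import XLMSequenceSummary
```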
@@ -33,23 +33,6 @@ Most of those are only useful if you are studying the code of the models in the
 
 [[autodoc]] pytorch_utils.Conv1D
 
-[[autodoc]] modeling_utils.PoolerStartLogits
-    - forward
-
-[[autodoc]] modeling_utils.PoolerEndLogits
-    - forward
-
-[[autodoc]] modeling_utils.PoolerAnswerClass
-    - forward
-
-[[autodoc]] modeling_utils.SquadHeadOutput
-
-[[autodoc]] modeling_utils.SQuADHead
-    - forward
-
-[[autodoc]] modeling_utils.SequenceSummary
-    - forward
-
 ## PyTorch Helper Functions
 
 [[autodoc]] pytorch_utils.apply_chunking_to_forward
@@ -25,23 +25,6 @@ rendered properly in your Markdown viewer.
 
 [[autodoc]] pytorch_utils.Conv1D
 
-[[autodoc]] modeling_utils.PoolerStartLogits
-    - forward
-
-[[autodoc]] modeling_utils.PoolerEndLogits
-    - forward
-
-[[autodoc]] modeling_utils.PoolerAnswerClass
-    - forward
-
-[[autodoc]] modeling_utils.SquadHeadOutput
-
-[[autodoc]] modeling_utils.SQuADHead
-    - forward
-
-[[autodoc]] modeling_utils.SequenceSummary
-    - forward
-
 ## PyTorch Helper Functions
 
 [[autodoc]] pytorch_utils.apply_chunking_to_forward
@@ -25,23 +25,6 @@ rendered properly in your Markdown viewer.
 
 [[autodoc]] pytorch_utils.Conv1D
 
-[[autodoc]] modeling_utils.PoolerStartLogits
-    - forward
-
-[[autodoc]] modeling_utils.PoolerEndLogits
-    - forward
-
-[[autodoc]] modeling_utils.PoolerAnswerClass
-    - forward
-
-[[autodoc]] modeling_utils.SquadHeadOutput
-
-[[autodoc]] modeling_utils.SQuADHead
-    - forward
-
-[[autodoc]] modeling_utils.SequenceSummary
-    - forward
-
 ## PyTorch 헬퍼(helper) 함수 [[transformers.apply_chunking_to_forward]]
 
 [[autodoc]] pytorch_utils.apply_chunking_to_forward
@@ -25,23 +25,6 @@ rendered properly in your Markdown viewer.
 
 [[autodoc]] pytorch_utils.Conv1D
 
-[[autodoc]] modeling_utils.PoolerStartLogits
-    - forward
-
-[[autodoc]] modeling_utils.PoolerEndLogits
-    - forward
-
-[[autodoc]] modeling_utils.PoolerAnswerClass
-    - forward
-
-[[autodoc]] modeling_utils.SquadHeadOutput
-
-[[autodoc]] modeling_utils.SQuADHead
-    - forward
-
-[[autodoc]] modeling_utils.SequenceSummary
-    - forward
-
 ## PyTorch帮助函数
 
 [[autodoc]] pytorch_utils.apply_chunking_to_forward
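After this cleanup the docs keep only the `pytorch_utils` entries (`Conv1D`, `apply_chunking_to_forward`, ...). As a quick reference, a hedged usage sketch of `apply_chunking_to_forward`, using its documented `(forward_fn, chunk_size, chunk_dim, *input_tensors)` signature:

```python
# Hedged sketch of the helper the docs still cover; the toy feed-forward function is made up.
import torch
from transformers.pytorch_utils import apply_chunking_to_forward


def feed_forward(hidden_states):
    return hidden_states * 2  # stand-in for an expensive feed-forward block


hidden_states = torch.randn(1, 16, 8)
# Process the sequence dimension (dim=1) in chunks of 4 to trade compute for peak memory.
out = apply_chunking_to_forward(feed_forward, 4, 1, hidden_states)
assert torch.allclose(out, hidden_states * 2)
```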
@@ -5384,6 +5384,10 @@ class PoolerStartLogits(nn.Module):
     def __init__(self, config: PretrainedConfig):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, 1)
+        logger.warning_once(
+            "[DEPRECATION WARNING] `PoolerStartLogits` is deprecated and will be removed in v4.53. "
+            "Please use model-specific class, e.g. `XLMPoolerStartLogits`."
+        )
 
     def forward(
         self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
@@ -5426,6 +5430,10 @@ class PoolerEndLogits(nn.Module):
         self.activation = nn.Tanh()
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dense_1 = nn.Linear(config.hidden_size, 1)
+        logger.warning_once(
+            "[DEPRECATION WARNING] `PoolerEndLogits` is deprecated and will be removed in v4.53. "
+            "Please use model-specific class, e.g. `XLMPoolerEndLogits`."
+        )
 
     def forward(
         self,
@@ -5493,6 +5501,10 @@ class PoolerAnswerClass(nn.Module):
         self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
         self.activation = nn.Tanh()
         self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
+        logger.warning_once(
+            "[DEPRECATION WARNING] `PoolerAnswerClass` is deprecated and will be removed in v4.53. "
+            "Please use model-specific class, e.g. `XLMPoolerAnswerClass`."
+        )
 
     def forward(
         self,
@@ -5574,6 +5586,12 @@ class SquadHeadOutput(ModelOutput):
     end_top_index: Optional[torch.LongTensor] = None
     cls_logits: Optional[torch.FloatTensor] = None
 
+    def __post_init__(self):
+        logger.warning_once(
+            "[DEPRECATION WARNING] `SquadHeadOutput` is deprecated and will be removed in v4.53. "
+            "Please use model-specific class, e.g. `XLMSquadHeadOutput`."
+        )
+
 
 class SQuADHead(nn.Module):
     r"""
@@ -5594,6 +5612,11 @@ class SQuADHead(nn.Module):
         self.end_logits = PoolerEndLogits(config)
         self.answer_class = PoolerAnswerClass(config)
 
+        logger.warning_once(
+            "[DEPRECATION WARNING] `SQuADHead` is deprecated and will be removed in v4.53. "
+            "Please use model-specific class, e.g. `XLMSQuADHead`."
+        )
+
     @replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig)
     def forward(
         self,
@@ -5747,6 +5770,11 @@ class SequenceSummary(nn.Module):
         if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
             self.last_dropout = nn.Dropout(config.summary_last_dropout)
 
+        logger.warning_once(
+            "[DEPRECATION WARNING] `SequenceSummary` is deprecated and will be removed in v4.53. "
+            "Please use model-specific class, e.g. `XLMSequenceSummary`."
+        )
+
     def forward(
         self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
     ) -> torch.FloatTensor:
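All six deprecations use `logger.warning_once`, so the notice is printed at most once per process no matter how many deprecated modules are constructed. A standalone sketch of that once-only pattern; this is only an illustration of the idea, not the `transformers.utils.logging` implementation:

```python
# Illustrative once-only warning, mimicking what `logger.warning_once` provides.
import functools
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("deprecations")


@functools.lru_cache(None)
def warning_once(message: str) -> None:
    # The lru_cache means a repeated message is only ever logged on the first call.
    logger.warning(message)


for _ in range(3):
    warning_once("[DEPRECATION WARNING] `SequenceSummary` is deprecated and will be removed in v4.53.")
# The message is emitted a single time, even though the call runs three times.
```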
@@ -18,14 +18,14 @@
 import copy
 import math
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Callable, Dict, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
 
-from ...activations import ACT2FN
+from ...activations import ACT2FN, get_activation
 from ...generation import GenerationConfig, GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
 from ...modeling_outputs import (
@@ -34,7 +34,7 @@ from ...modeling_outputs import (
     BaseModelOutputWithPooling,
     CausalLMOutputWithCrossAttentions,
 )
-from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import Conv1D, isin_mps_friendly
 from ...utils import (
     ModelOutput,
@@ -499,6 +499,106 @@ class ClvpEncoderLayer(nn.Module):
         return outputs
 
 
+# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Clvp
+class ClvpSequenceSummary(nn.Module):
+    r"""
+    Compute a single vector summary of a sequence hidden states.
+
+    Args:
+        config ([`ClvpConfig`]):
+            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+            config class of your model for the default values it uses):
+
+            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+                - `"last"` -- Take the last token hidden state (like XLNet)
+                - `"first"` -- Take the first token hidden state (like Bert)
+                - `"mean"` -- Take the mean of all tokens hidden states
+                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+                - `"attn"` -- Not implemented now, use multi-head attention
+
+            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+              (otherwise to `config.hidden_size`).
+            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+              another string or `None` will add no activation.
+            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
+    """
+
+    def __init__(self, config: ClvpConfig):
+        super().__init__()
+
+        self.summary_type = getattr(config, "summary_type", "last")
+        if self.summary_type == "attn":
+            # We should use a standard multi-head attention module with absolute positional embedding for that.
+            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+            raise NotImplementedError
+
+        self.summary = nn.Identity()
+        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
+            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+                num_classes = config.num_labels
+            else:
+                num_classes = config.hidden_size
+            self.summary = nn.Linear(config.hidden_size, num_classes)
+
+        activation_string = getattr(config, "summary_activation", None)
+        self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
+
+        self.first_dropout = nn.Identity()
+        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
+            self.first_dropout = nn.Dropout(config.summary_first_dropout)
+
+        self.last_dropout = nn.Identity()
+        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
+            self.last_dropout = nn.Dropout(config.summary_last_dropout)
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
+    ) -> torch.FloatTensor:
+        """
+        Compute a single vector summary of a sequence hidden states.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
+                The hidden states of the last layer.
+            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
+
+        Returns:
+            `torch.FloatTensor`: The summary of the sequence hidden states.
+        """
+        if self.summary_type == "last":
+            output = hidden_states[:, -1]
+        elif self.summary_type == "first":
+            output = hidden_states[:, 0]
+        elif self.summary_type == "mean":
+            output = hidden_states.mean(dim=1)
+        elif self.summary_type == "cls_index":
+            if cls_index is None:
+                cls_index = torch.full_like(
+                    hidden_states[..., :1, :],
+                    hidden_states.shape[-2] - 1,
+                    dtype=torch.long,
+                )
+            else:
+                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
+                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
+            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
+        elif self.summary_type == "attn":
+            raise NotImplementedError
+
+        output = self.first_dropout(output)
+        output = self.summary(output)
+        output = self.activation(output)
+        output = self.last_dropout(output)
+
+        return output
+
+
 # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->ClvpDecoderMLP
 class ClvpDecoderMLP(nn.Module):
     def __init__(self, intermediate_size, config):
@@ -884,7 +984,7 @@ class ClvpEncoder(ClvpPreTrainedModel):
         self.rotary_pos_emb = ClvpRotaryPositionalEmbedding(config) if config.use_rotary_embedding else None
         self.layers = nn.ModuleList([ClvpEncoderLayer(config) for _ in range(config.num_hidden_layers)])
 
-        self.sequence_summary = SequenceSummary(config)
+        self.sequence_summary = ClvpSequenceSummary(config)
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
         self.projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
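The `cls_index` branch of the copied summary class is the least obvious part: it gathers one hidden state per example along the sequence dimension, defaulting to the last token. A standalone shape check of that gather, mirroring the code above (plain `torch`, no model or config needed):

```python
# Minimal sketch of what the "cls_index" branch of the copied SequenceSummary forward computes.
import torch

batch_size, seq_len, hidden_size = 2, 5, 4
hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# Default: summarize with the last token of each sequence.
cls_index = torch.full_like(hidden_states[..., :1, :], seq_len - 1, dtype=torch.long)
last_token = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (batch_size, hidden_size)
assert torch.allclose(last_token, hidden_states[:, -1])

# Explicit per-example classification-token positions.
positions = torch.tensor([1, 3])
idx = positions.unsqueeze(-1).unsqueeze(-1).expand(-1, 1, hidden_size)
picked = hidden_states.gather(-2, idx).squeeze(-2)  # shape (batch_size, hidden_size)
assert torch.allclose(picked[0], hidden_states[0, 1]) and torch.allclose(picked[1], hidden_states[1, 3])
```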
@@ -17,7 +17,7 @@
 import math
 import os
 from operator import attrgetter
-from typing import Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
@@ -33,7 +33,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_convbert import ConvBertConfig
@@ -683,6 +683,106 @@ class ConvBertPredictionHeadTransform(nn.Module):
         return hidden_states
 
 
+# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->ConvBert
+class ConvBertSequenceSummary(nn.Module):
+    r"""
+    Compute a single vector summary of a sequence hidden states.
+
+    Args:
+        config ([`ConvBertConfig`]):
+            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+            config class of your model for the default values it uses):
+
+            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+                - `"last"` -- Take the last token hidden state (like XLNet)
+                - `"first"` -- Take the first token hidden state (like Bert)
+                - `"mean"` -- Take the mean of all tokens hidden states
+                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+                - `"attn"` -- Not implemented now, use multi-head attention
+
+            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+              (otherwise to `config.hidden_size`).
+            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+              another string or `None` will add no activation.
+            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
+    """
+
+    def __init__(self, config: ConvBertConfig):
+        super().__init__()
+
+        self.summary_type = getattr(config, "summary_type", "last")
+        if self.summary_type == "attn":
+            # We should use a standard multi-head attention module with absolute positional embedding for that.
+            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+            raise NotImplementedError
+
+        self.summary = nn.Identity()
+        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
+            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+                num_classes = config.num_labels
+            else:
+                num_classes = config.hidden_size
+            self.summary = nn.Linear(config.hidden_size, num_classes)
+
+        activation_string = getattr(config, "summary_activation", None)
+        self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
+
+        self.first_dropout = nn.Identity()
+        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
+            self.first_dropout = nn.Dropout(config.summary_first_dropout)
+
+        self.last_dropout = nn.Identity()
+        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
+            self.last_dropout = nn.Dropout(config.summary_last_dropout)
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
+    ) -> torch.FloatTensor:
+        """
+        Compute a single vector summary of a sequence hidden states.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
+                The hidden states of the last layer.
+            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
+
+        Returns:
+            `torch.FloatTensor`: The summary of the sequence hidden states.
+        """
+        if self.summary_type == "last":
+            output = hidden_states[:, -1]
+        elif self.summary_type == "first":
+            output = hidden_states[:, 0]
+        elif self.summary_type == "mean":
+            output = hidden_states.mean(dim=1)
+        elif self.summary_type == "cls_index":
+            if cls_index is None:
+                cls_index = torch.full_like(
+                    hidden_states[..., :1, :],
+                    hidden_states.shape[-2] - 1,
+                    dtype=torch.long,
+                )
+            else:
+                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
+                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
+            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
+        elif self.summary_type == "attn":
+            raise NotImplementedError
+
+        output = self.first_dropout(output)
+        output = self.summary(output)
+        output = self.activation(output)
+        output = self.last_dropout(output)
+
+        return output
+
+
 CONVBERT_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
     it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
@@ -1077,7 +1177,7 @@ class ConvBertForMultipleChoice(ConvBertPreTrainedModel):
         super().__init__(config)
 
         self.convbert = ConvBertModel(config)
-        self.sequence_summary = SequenceSummary(config)
+        self.sequence_summary = ConvBertSequenceSummary(config)
         self.classifier = nn.Linear(config.hidden_size, 1)
 
         # Initialize weights and apply final processing
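`ConvBertForMultipleChoice` pairs the new `ConvBertSequenceSummary` with an `nn.Linear(hidden_size, 1)` classifier. A hedged shape sketch of how such a head turns per-choice summaries into one logit per choice; the flattening of choices into the batch is the usual multiple-choice convention, not code from this diff, and the "last"-token summary is only the copied class's fallback default:

```python
# Shape walk-through only, with made-up sizes.
import torch
from torch import nn

batch, num_choices, seq_len, hidden = 2, 4, 7, 8
encoder_out = torch.randn(batch * num_choices, seq_len, hidden)  # choices flattened into the batch

summary = encoder_out[:, -1]            # one summary vector per (example, choice)
classifier = nn.Linear(hidden, 1)
logits = classifier(summary).view(batch, num_choices)
print(logits.shape)  # torch.Size([2, 4]) -> one score per choice
```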
@@ -17,7 +17,7 @@
 import math
 import os
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
@@ -36,7 +36,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
@@ -946,6 +946,106 @@ class ElectraClassificationHead(nn.Module):
         return x
 
 
+# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Electra
+class ElectraSequenceSummary(nn.Module):
+    r"""
+    Compute a single vector summary of a sequence hidden states.
+
+    Args:
+        config ([`ElectraConfig`]):
+            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+            config class of your model for the default values it uses):
+
+            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+                - `"last"` -- Take the last token hidden state (like XLNet)
+                - `"first"` -- Take the first token hidden state (like Bert)
+                - `"mean"` -- Take the mean of all tokens hidden states
+                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+                - `"attn"` -- Not implemented now, use multi-head attention
+
+            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+              (otherwise to `config.hidden_size`).
+            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+              another string or `None` will add no activation.
+            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
+    """
+
+    def __init__(self, config: ElectraConfig):
+        super().__init__()
+
+        self.summary_type = getattr(config, "summary_type", "last")
+        if self.summary_type == "attn":
+            # We should use a standard multi-head attention module with absolute positional embedding for that.
+            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+            raise NotImplementedError
+
+        self.summary = nn.Identity()
+        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
+            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+                num_classes = config.num_labels
+            else:
+                num_classes = config.hidden_size
+            self.summary = nn.Linear(config.hidden_size, num_classes)
+
+        activation_string = getattr(config, "summary_activation", None)
+        self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
+
+        self.first_dropout = nn.Identity()
+        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
+            self.first_dropout = nn.Dropout(config.summary_first_dropout)
+
+        self.last_dropout = nn.Identity()
+        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
+            self.last_dropout = nn.Dropout(config.summary_last_dropout)
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
+    ) -> torch.FloatTensor:
+        """
+        Compute a single vector summary of a sequence hidden states.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
+                The hidden states of the last layer.
+            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
+
+        Returns:
+            `torch.FloatTensor`: The summary of the sequence hidden states.
+        """
+        if self.summary_type == "last":
+            output = hidden_states[:, -1]
+        elif self.summary_type == "first":
+            output = hidden_states[:, 0]
+        elif self.summary_type == "mean":
+            output = hidden_states.mean(dim=1)
+        elif self.summary_type == "cls_index":
+            if cls_index is None:
+                cls_index = torch.full_like(
+                    hidden_states[..., :1, :],
+                    hidden_states.shape[-2] - 1,
+                    dtype=torch.long,
+                )
+            else:
+                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
+                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
+            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
+        elif self.summary_type == "attn":
+            raise NotImplementedError
+
+        output = self.first_dropout(output)
+        output = self.summary(output)
+        output = self.activation(output)
+        output = self.last_dropout(output)
+
+        return output
+
+
 @add_start_docstrings(
     """
     ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of the
@@ -1442,7 +1542,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
         super().__init__(config)
 
         self.electra = ElectraModel(config)
-        self.sequence_summary = SequenceSummary(config)
+        self.sequence_summary = ElectraSequenceSummary(config)
         self.classifier = nn.Linear(config.hidden_size, 1)
 
         # Initialize weights and apply final processing
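Each copied `*SequenceSummary` resolves its activation from `config.summary_activation` via the newly imported `get_activation` helper, falling back to `nn.Identity()` when unset. A small hedged sketch of that resolution step; the top-level import path is assumed from the relative `...activations` import in the diff:

```python
import torch
from torch import nn
from transformers.activations import get_activation  # assumed public path for the `...activations` helper

activation_string = "tanh"  # e.g. config.summary_activation
activation = get_activation(activation_string) if activation_string else nn.Identity()
print(activation(torch.tensor([-1.0, 0.0, 1.0])))  # tanh applied elementwise
```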
@@ -17,14 +17,14 @@
 import itertools
 import math
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Callable, Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
-from ...activations import gelu
+from ...activations import gelu, get_activation
 from ...generation import GenerationMixin
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -34,7 +34,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead
+from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
@@ -329,6 +329,431 @@ class FlaubertPredLayer(nn.Module):
         return outputs
 
 
+@dataclass
+# Copied from transformers.models.xlm.modeling_xlm.XLMSquadHeadOutput with XLM->Flaubert
+class FlaubertSquadHeadOutput(ModelOutput):
+    """
+    Base class for outputs of question answering models using a [`~modeling_utils.FlaubertSQuADHead`].
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
+            Classification loss as the sum of start token, end token (and is_impossible if provided) classification
+            losses.
+        start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
+            Log probabilities for the top config.start_n_top start token possibilities (beam-search).
+        start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
+            Indices for the top config.start_n_top start token possibilities (beam-search).
+        end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
+            Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
+            (beam-search).
+        end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
+            Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
+        cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
+            Log probabilities for the `is_impossible` label of the answers.
+
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    start_top_log_probs: Optional[torch.FloatTensor] = None
+    start_top_index: Optional[torch.LongTensor] = None
+    end_top_log_probs: Optional[torch.FloatTensor] = None
+    end_top_index: Optional[torch.LongTensor] = None
+    cls_logits: Optional[torch.FloatTensor] = None
+
+
+# Copied from transformers.models.xlm.modeling_xlm.XLMPoolerStartLogits with XLM->Flaubert
+class FlaubertPoolerStartLogits(nn.Module):
+    """
+    Compute SQuAD start logits from sequence hidden states.
+
+    Args:
+        config ([`FlaubertConfig`]):
+            The config used by the model, will be used to grab the `hidden_size` of the model.
+    """
+
+    def __init__(self, config: FlaubertConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, 1)
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
+    ) -> torch.FloatTensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
+                The final hidden states of the model.
+            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
+                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
+                should be masked.
+
+        Returns:
+            `torch.FloatTensor`: The start logits for SQuAD.
+        """
+        x = self.dense(hidden_states).squeeze(-1)
+
+        if p_mask is not None:
+            if p_mask.dtype == torch.float16:
+                x = x * (1 - p_mask) - 65500 * p_mask
+            else:
+                x = x * (1 - p_mask) - 1e30 * p_mask
+
+        return x
+
+
+# Copied from transformers.models.xlm.modeling_xlm.XLMPoolerEndLogits with XLM->Flaubert
+class FlaubertPoolerEndLogits(nn.Module):
+    """
+    Compute SQuAD end logits from sequence hidden states.
+
+    Args:
+        config ([`FlaubertConfig`]):
+            The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
+            to use.
+    """
+
+    def __init__(self, config: FlaubertConfig):
+        super().__init__()
+        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
+        self.activation = nn.Tanh()
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dense_1 = nn.Linear(config.hidden_size, 1)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        start_states: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        p_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
+                The final hidden states of the model.
+            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
+                The hidden states of the first tokens for the labeled span.
+            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                The position of the first token for the labeled span.
+            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
+                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
+                should be masked.
+
+        <Tip>
+
+        One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
+        `start_states`.
+
+        </Tip>
+
+        Returns:
+            `torch.FloatTensor`: The end logits for SQuAD.
+        """
+        assert start_states is not None or start_positions is not None, (
+            "One of start_states, start_positions should be not None"
+        )
+        if start_positions is not None:
+            slen, hsz = hidden_states.shape[-2:]
+            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
+            start_states = hidden_states.gather(-2, start_positions)  # shape (bsz, 1, hsz)
+            start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)
+
+        x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
+        x = self.activation(x)
+        x = self.LayerNorm(x)
+        x = self.dense_1(x).squeeze(-1)
+
+        if p_mask is not None:
+            if p_mask.dtype == torch.float16:
+                x = x * (1 - p_mask) - 65500 * p_mask
+            else:
+                x = x * (1 - p_mask) - 1e30 * p_mask
+
+        return x
+
+
+# Copied from transformers.models.xlm.modeling_xlm.XLMPoolerAnswerClass with XLM->Flaubert
+class FlaubertPoolerAnswerClass(nn.Module):
+    """
+    Compute SQuAD 2.0 answer class from classification and start tokens hidden states.
+
+    Args:
+        config ([`FlaubertConfig`]):
+            The config used by the model, will be used to grab the `hidden_size` of the model.
+    """
+
+    def __init__(self, config: FlaubertConfig):
+        super().__init__()
+        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
+        self.activation = nn.Tanh()
+        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        start_states: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        cls_index: Optional[torch.LongTensor] = None,
+    ) -> torch.FloatTensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
+                The final hidden states of the model.
+            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
+                The hidden states of the first tokens for the labeled span.
+            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                The position of the first token for the labeled span.
+            cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
+
+        <Tip>
+
+        One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
+        `start_states`.
+
+        </Tip>
+
+        Returns:
+            `torch.FloatTensor`: The SQuAD 2.0 answer class.
+        """
+        # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
+        hsz = hidden_states.shape[-1]
+        assert start_states is not None or start_positions is not None, (
+            "One of start_states, start_positions should be not None"
+        )
+        if start_positions is not None:
+            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
+            start_states = hidden_states.gather(-2, start_positions).squeeze(-2)  # shape (bsz, hsz)
+
+        if cls_index is not None:
+            cls_index = cls_index[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
+            cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, hsz)
+        else:
+            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)
+
+        x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
+        x = self.activation(x)
+        x = self.dense_1(x).squeeze(-1)
+
+        return x
+
+
+# Copied from transformers.models.xlm.modeling_xlm.XLMSQuADHead with XLM->Flaubert
+class FlaubertSQuADHead(nn.Module):
+    r"""
+    A SQuAD head inspired by XLNet.
+
+    Args:
+        config ([`FlaubertConfig`]):
+            The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
+            to use.
+    """
+
+    def __init__(self, config: FlaubertConfig):
+        super().__init__()
+        self.start_n_top = config.start_n_top
+        self.end_n_top = config.end_n_top
+
+        self.start_logits = FlaubertPoolerStartLogits(config)
+        self.end_logits = FlaubertPoolerEndLogits(config)
+        self.answer_class = FlaubertPoolerAnswerClass(config)
+
+    @replace_return_docstrings(output_type=FlaubertSquadHeadOutput, config_class=FlaubertConfig)
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        cls_index: Optional[torch.LongTensor] = None,
+        is_impossible: Optional[torch.LongTensor] = None,
+        p_mask: Optional[torch.FloatTensor] = None,
+        return_dict: bool = False,
+    ) -> Union[FlaubertSquadHeadOutput, Tuple[torch.FloatTensor]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
+                Final hidden states of the model on the sequence tokens.
+            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                Positions of the first token for the labeled span.
+            end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                Positions of the last token for the labeled span.
+            cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
+            is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                Whether the question has a possible answer in the paragraph or not.
+            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
+                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
+                should be masked.
+            return_dict (`bool`, *optional*, defaults to `False`):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+        """
+        start_logits = self.start_logits(hidden_states, p_mask=p_mask)
+
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, let's remove the dimension added by batch splitting
+            for x in (start_positions, end_positions, cls_index, is_impossible):
+                if x is not None and x.dim() > 1:
+                    x.squeeze_(-1)
+
+            # during training, compute the end logits based on the ground truth of the start position
+            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
+
+            loss_fct = CrossEntropyLoss()
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+            if cls_index is not None and is_impossible is not None:
+                # Predict answerability from the representation of CLS and START
+                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
+                loss_fct_cls = nn.BCEWithLogitsLoss()
+                cls_loss = loss_fct_cls(cls_logits, is_impossible)
+
+                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
+                total_loss += cls_loss * 0.5
+
+            return FlaubertSquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)
+
+        else:
+            # during inference, compute the end logits based on beam search
+            bsz, slen, hsz = hidden_states.size()
+            start_log_probs = nn.functional.softmax(start_logits, dim=-1)  # shape (bsz, slen)
+
+            start_top_log_probs, start_top_index = torch.topk(
+                start_log_probs, self.start_n_top, dim=-1
+            )  # shape (bsz, start_n_top)
+            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
+            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
+            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)
+
+            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
+                start_states
+            )  # shape (bsz, slen, start_n_top, hsz)
+            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
+            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
+            end_log_probs = nn.functional.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)
+
+            end_top_log_probs, end_top_index = torch.topk(
+                end_log_probs, self.end_n_top, dim=1
+            )  # shape (bsz, end_n_top, start_n_top)
+            end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
+            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
+
+            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
+            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
+
+            if not return_dict:
+                return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
+            else:
+                return FlaubertSquadHeadOutput(
+                    start_top_log_probs=start_top_log_probs,
+                    start_top_index=start_top_index,
+                    end_top_log_probs=end_top_log_probs,
+                    end_top_index=end_top_index,
+                    cls_logits=cls_logits,
+                )
+
+
+# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Flaubert
+class FlaubertSequenceSummary(nn.Module):
+    r"""
+    Compute a single vector summary of a sequence hidden states.
+
+    Args:
+        config ([`FlaubertConfig`]):
+            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+            config class of your model for the default values it uses):
+
+            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+                - `"last"` -- Take the last token hidden state (like XLNet)
+                - `"first"` -- Take the first token hidden state (like Bert)
+                - `"mean"` -- Take the mean of all tokens hidden states
+                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+                - `"attn"` -- Not implemented now, use multi-head attention
+
+            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+              (otherwise to `config.hidden_size`).
+            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+              another string or `None` will add no activation.
+            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
+    """
+
+    def __init__(self, config: FlaubertConfig):
+        super().__init__()
+
+        self.summary_type = getattr(config, "summary_type", "last")
+        if self.summary_type == "attn":
+            # We should use a standard multi-head attention module with absolute positional embedding for that.
+            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+            raise NotImplementedError
+
+        self.summary = nn.Identity()
+        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
+            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+                num_classes = config.num_labels
+            else:
+                num_classes = config.hidden_size
+            self.summary = nn.Linear(config.hidden_size, num_classes)
+
+        activation_string = getattr(config, "summary_activation", None)
+        self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
+
+        self.first_dropout = nn.Identity()
+        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
+            self.first_dropout = nn.Dropout(config.summary_first_dropout)
+
+        self.last_dropout = nn.Identity()
+        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
+            self.last_dropout = nn.Dropout(config.summary_last_dropout)
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
+    ) -> torch.FloatTensor:
+        """
+        Compute a single vector summary of a sequence hidden states.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
+                The hidden states of the last layer.
+            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
+
+        Returns:
+            `torch.FloatTensor`: The summary of the sequence hidden states.
+        """
+        if self.summary_type == "last":
+            output = hidden_states[:, -1]
+        elif self.summary_type == "first":
+            output = hidden_states[:, 0]
+        elif self.summary_type == "mean":
+            output = hidden_states.mean(dim=1)
+        elif self.summary_type == "cls_index":
+            if cls_index is None:
+                cls_index = torch.full_like(
+                    hidden_states[..., :1, :],
+                    hidden_states.shape[-2] - 1,
+                    dtype=torch.long,
+                )
+            else:
+                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
+                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
+            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
+        elif self.summary_type == "attn":
+            raise NotImplementedError
+
+        output = self.first_dropout(output)
+        output = self.summary(output)
+        output = self.activation(output)
+        output = self.last_dropout(output)
+
+        return output
+
+
 # Copied from transformers.models.xlm.modeling_xlm.XLMPreTrainedModel with XLM->Flaubert
 class FlaubertPreTrainedModel(PreTrainedModel):
     """
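The inference branch of `FlaubertSQuADHead` is the densest part of the copied code: it beam-searches `start_n_top` start positions, scores `end_n_top` ends for each, and flattens the joint beam. A standalone shape walk-through of those tensor manipulations (plain `torch` with made-up sizes; `start_n_top`/`end_n_top` mirror the config attributes used above):

```python
import torch

bsz, slen, hsz, start_n_top, end_n_top = 2, 10, 8, 5, 3
hidden_states = torch.randn(bsz, slen, hsz)
start_logits = torch.randn(bsz, slen)

start_log_probs = torch.softmax(start_logits, dim=-1)
start_top_log_probs, start_top_index = torch.topk(start_log_probs, start_n_top, dim=-1)
print(start_top_log_probs.shape)  # torch.Size([2, 5]) -> (bsz, start_n_top)

# Gather the hidden state of each candidate start position.
start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)
start_states = torch.gather(hidden_states, -2, start_top_index_exp)
print(start_states.shape)  # torch.Size([2, 5, 8]) -> (bsz, start_n_top, hsz)

# After the end-logits pass, the head flattens the joint (start, end) beam:
end_top_log_probs = torch.randn(bsz, end_n_top, start_n_top).view(-1, start_n_top * end_n_top)
print(end_top_log_probs.shape)  # torch.Size([2, 15]) -> (bsz, start_n_top * end_n_top)
```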
@@ -744,7 +1169,7 @@ class FlaubertWithLMHeadModel(FlaubertPreTrainedModel, GenerationMixin):
     """,
     FLAUBERT_START_DOCSTRING,
 )
-# Copied transformers.models.xlm.modeling_xlm.XLMForSequenceClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
+# Copied from transformers.models.xlm.modeling_xlm.XLMForSequenceClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
 class FlaubertForSequenceClassification(FlaubertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
@@ -752,7 +1177,7 @@ class FlaubertForSequenceClassification(FlaubertPreTrainedModel):
         self.config = config

         self.transformer = FlaubertModel(config)
-        self.sequence_summary = SequenceSummary(config)
+        self.sequence_summary = FlaubertSequenceSummary(config)

         # Initialize weights and apply final processing
         self.post_init()
@@ -1081,13 +1506,13 @@ class FlaubertForQuestionAnsweringOutput(ModelOutput):
     attentions: Optional[Tuple[torch.FloatTensor]] = None


-# Copied from transformer.models.xlm.modeling_xlm.XLMForQuestionAnswering with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
+# Copied from transformers.models.xlm.modeling_xlm.XLMForQuestionAnswering with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
 class FlaubertForQuestionAnswering(FlaubertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)

         self.transformer = FlaubertModel(config)
-        self.qa_outputs = SQuADHead(config)
+        self.qa_outputs = FlaubertSQuADHead(config)

         # Initialize weights and apply final processing
         self.post_init()
@@ -1137,11 +1562,11 @@ class FlaubertForQuestionAnswering(FlaubertPreTrainedModel):
         Example:

         ```python
-        >>> from transformers import XLMTokenizer, XLMForQuestionAnswering
+        >>> from transformers import AutoTokenizer, FlaubertForQuestionAnswering
         >>> import torch

-        >>> tokenizer = XLMTokenizer.from_pretrained("FacebookAI/xlm-mlm-en-2048")
+        >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-mlm-en-2048")
-        >>> model = XLMForQuestionAnswering.from_pretrained("FacebookAI/xlm-mlm-en-2048")
+        >>> model = FlaubertForQuestionAnswering.from_pretrained("FacebookAI/xlm-mlm-en-2048")

         >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(
         ...     0
@@ -1203,13 +1628,13 @@ class FlaubertForQuestionAnswering(FlaubertPreTrainedModel):
     """,
     FLAUBERT_START_DOCSTRING,
 )
-# Copied from transformer.models.xlm.modeling_xlm.XLMForMultipleChoice with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
+# Copied from transformers.models.xlm.modeling_xlm.XLMForMultipleChoice with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
 class FlaubertForMultipleChoice(FlaubertPreTrainedModel):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)

         self.transformer = FlaubertModel(config)
-        self.sequence_summary = SequenceSummary(config)
+        self.sequence_summary = FlaubertSequenceSummary(config)
         self.logits_proj = nn.Linear(config.num_labels, 1)

         # Initialize weights and apply final processing
@@ -26,7 +26,7 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

-from ...activations import ACT2FN
+from ...activations import ACT2FN, get_activation
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
 from ...modeling_outputs import (
@@ -36,7 +36,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
 )
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, SequenceSummary
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import (
     ModelOutput,
@@ -450,6 +450,106 @@ class GPT2Block(nn.Module):
         return outputs  # hidden_states, present, (attentions, cross_attentions)


+# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->GPT2
+class GPT2SequenceSummary(nn.Module):
+    r"""
+    Compute a single vector summary of a sequence hidden states.
+
+    Args:
+        config ([`GPT2Config`]):
+            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+            config class of your model for the default values it uses):
+
+            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+                - `"last"` -- Take the last token hidden state (like XLNet)
+                - `"first"` -- Take the first token hidden state (like Bert)
+                - `"mean"` -- Take the mean of all tokens hidden states
+                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+                - `"attn"` -- Not implemented now, use multi-head attention
+
+            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+              (otherwise to `config.hidden_size`).
+            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+              another string or `None` will add no activation.
+            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
+    """
+
+    def __init__(self, config: GPT2Config):
+        super().__init__()
+
+        self.summary_type = getattr(config, "summary_type", "last")
+        if self.summary_type == "attn":
+            # We should use a standard multi-head attention module with absolute positional embedding for that.
+            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+            raise NotImplementedError
+
+        self.summary = nn.Identity()
+        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
+            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+                num_classes = config.num_labels
+            else:
+                num_classes = config.hidden_size
+            self.summary = nn.Linear(config.hidden_size, num_classes)
+
+        activation_string = getattr(config, "summary_activation", None)
+        self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
+
+        self.first_dropout = nn.Identity()
+        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
+            self.first_dropout = nn.Dropout(config.summary_first_dropout)
+
+        self.last_dropout = nn.Identity()
+        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
+            self.last_dropout = nn.Dropout(config.summary_last_dropout)
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
+    ) -> torch.FloatTensor:
+        """
+        Compute a single vector summary of a sequence hidden states.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
+                The hidden states of the last layer.
+            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
+
+        Returns:
+            `torch.FloatTensor`: The summary of the sequence hidden states.
+        """
+        if self.summary_type == "last":
+            output = hidden_states[:, -1]
+        elif self.summary_type == "first":
+            output = hidden_states[:, 0]
+        elif self.summary_type == "mean":
+            output = hidden_states.mean(dim=1)
+        elif self.summary_type == "cls_index":
+            if cls_index is None:
+                cls_index = torch.full_like(
+                    hidden_states[..., :1, :],
+                    hidden_states.shape[-2] - 1,
+                    dtype=torch.long,
+                )
+            else:
+                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
+                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
+            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
+        elif self.summary_type == "attn":
+            raise NotImplementedError
+
+        output = self.first_dropout(output)
+        output = self.summary(output)
+        output = self.activation(output)
+        output = self.last_dropout(output)
+
+        return output
+
+
 class GPT2PreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -1138,7 +1238,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
         config.num_labels = 1
         self.transformer = GPT2Model(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-        self.multiple_choice_head = SequenceSummary(config)
+        self.multiple_choice_head = GPT2SequenceSummary(config)

         # Model parallel
         self.model_parallel = False
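For context, the main consumer of the `cls_index` path in `GPT2SequenceSummary` is `GPT2DoubleHeadsModel`, where `mc_token_ids` ends up as the `cls_index` argument. A hedged usage sketch, closely following the pattern of the model's own docstring example (the checkpoint name and the `[CLS]` handling are the usual stock GPT-2 ones, not taken from this diff):

```python
import torch

from transformers import AutoTokenizer, GPT2DoubleHeadsModel

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2")

# Add a [CLS] token whose position is passed as mc_token_ids (-> cls_index in the summary head).
tokenizer.add_special_tokens({"cls_token": "[CLS]"})
model.resize_token_embeddings(len(tokenizer))

choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
encoded = [tokenizer.encode(choice) for choice in choices]
input_ids = torch.tensor(encoded).unsqueeze(0)                 # (batch=1, num_choices=2, seq_len)
mc_token_ids = torch.tensor([[len(ids) - 1 for ids in encoded]])

outputs = model(input_ids, mc_token_ids=mc_token_ids)
print(outputs.mc_logits.shape)                                 # torch.Size([1, 2])
```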
@@ -19,16 +19,16 @@ import json
 import math
 import os
 from dataclasses import dataclass
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union

 import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

-from ...activations import gelu_new, silu
+from ...activations import gelu_new, get_activation, silu
 from ...generation import GenerationMixin
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
-from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import (
     ModelOutput,
@@ -262,6 +262,106 @@ class Block(nn.Module):
         return outputs


+# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->OpenAIGPT
+class OpenAIGPTSequenceSummary(nn.Module):
+    r"""
+    Compute a single vector summary of a sequence hidden states.
+
+    Args:
+        config ([`OpenAIGPTConfig`]):
+            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+            config class of your model for the default values it uses):
+
+            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+                - `"last"` -- Take the last token hidden state (like XLNet)
+                - `"first"` -- Take the first token hidden state (like Bert)
+                - `"mean"` -- Take the mean of all tokens hidden states
+                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+                - `"attn"` -- Not implemented now, use multi-head attention
+
+            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+              (otherwise to `config.hidden_size`).
+            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+              another string or `None` will add no activation.
+            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
+    """
+
+    def __init__(self, config: OpenAIGPTConfig):
+        super().__init__()
+
+        self.summary_type = getattr(config, "summary_type", "last")
+        if self.summary_type == "attn":
+            # We should use a standard multi-head attention module with absolute positional embedding for that.
+            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+            raise NotImplementedError
+
+        self.summary = nn.Identity()
+        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
+            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+                num_classes = config.num_labels
+            else:
+                num_classes = config.hidden_size
+            self.summary = nn.Linear(config.hidden_size, num_classes)
+
+        activation_string = getattr(config, "summary_activation", None)
+        self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
+
+        self.first_dropout = nn.Identity()
+        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
+            self.first_dropout = nn.Dropout(config.summary_first_dropout)
+
+        self.last_dropout = nn.Identity()
+        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
+            self.last_dropout = nn.Dropout(config.summary_last_dropout)
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
+    ) -> torch.FloatTensor:
+        """
+        Compute a single vector summary of a sequence hidden states.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
+                The hidden states of the last layer.
+            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
+
+        Returns:
+            `torch.FloatTensor`: The summary of the sequence hidden states.
+        """
+        if self.summary_type == "last":
+            output = hidden_states[:, -1]
+        elif self.summary_type == "first":
+            output = hidden_states[:, 0]
+        elif self.summary_type == "mean":
+            output = hidden_states.mean(dim=1)
+        elif self.summary_type == "cls_index":
+            if cls_index is None:
+                cls_index = torch.full_like(
+                    hidden_states[..., :1, :],
+                    hidden_states.shape[-2] - 1,
+                    dtype=torch.long,
+                )
+            else:
+                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
+                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
+            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
+        elif self.summary_type == "attn":
+            raise NotImplementedError
+
+        output = self.first_dropout(output)
+        output = self.summary(output)
+        output = self.activation(output)
+        output = self.last_dropout(output)
+
+        return output
+
+
 class OpenAIGPTPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -628,7 +728,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         config.num_labels = 1
         self.transformer = OpenAIGPTModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-        self.multiple_choice_head = SequenceSummary(config)
+        self.multiple_choice_head = OpenAIGPTSequenceSummary(config)

         # Initialize weights and apply final processing
         self.post_init()
@@ -16,7 +16,7 @@

 import math
 import os
-from typing import Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -24,7 +24,7 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

-from ...activations import ACT2FN
+from ...activations import ACT2FN, get_activation
 from ...generation import GenerationMixin
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -35,7 +35,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
@@ -620,6 +620,106 @@ class RoFormerEncoder(nn.Module):
         )


+# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->RoFormer
+class RoFormerSequenceSummary(nn.Module):
+    r"""
+    Compute a single vector summary of a sequence hidden states.
+
+    Args:
+        config ([`RoFormerConfig`]):
+            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+            config class of your model for the default values it uses):
+
+            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+                - `"last"` -- Take the last token hidden state (like XLNet)
+                - `"first"` -- Take the first token hidden state (like Bert)
+                - `"mean"` -- Take the mean of all tokens hidden states
+                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+                - `"attn"` -- Not implemented now, use multi-head attention
+
+            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+              (otherwise to `config.hidden_size`).
+            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+              another string or `None` will add no activation.
+            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
+    """
+
+    def __init__(self, config: RoFormerConfig):
+        super().__init__()
+
+        self.summary_type = getattr(config, "summary_type", "last")
+        if self.summary_type == "attn":
+            # We should use a standard multi-head attention module with absolute positional embedding for that.
+            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+            raise NotImplementedError
+
+        self.summary = nn.Identity()
+        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
+            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+                num_classes = config.num_labels
+            else:
+                num_classes = config.hidden_size
+            self.summary = nn.Linear(config.hidden_size, num_classes)
+
+        activation_string = getattr(config, "summary_activation", None)
+        self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
+
+        self.first_dropout = nn.Identity()
+        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
+            self.first_dropout = nn.Dropout(config.summary_first_dropout)
+
+        self.last_dropout = nn.Identity()
+        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
+            self.last_dropout = nn.Dropout(config.summary_last_dropout)
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
+    ) -> torch.FloatTensor:
+        """
+        Compute a single vector summary of a sequence hidden states.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
+                The hidden states of the last layer.
+            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
+
+        Returns:
+            `torch.FloatTensor`: The summary of the sequence hidden states.
+        """
+        if self.summary_type == "last":
+            output = hidden_states[:, -1]
+        elif self.summary_type == "first":
+            output = hidden_states[:, 0]
+        elif self.summary_type == "mean":
+            output = hidden_states.mean(dim=1)
+        elif self.summary_type == "cls_index":
+            if cls_index is None:
+                cls_index = torch.full_like(
+                    hidden_states[..., :1, :],
+                    hidden_states.shape[-2] - 1,
+                    dtype=torch.long,
+                )
+            else:
+                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
+                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
+            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
+        elif self.summary_type == "attn":
+            raise NotImplementedError
+
+        output = self.first_dropout(output)
+        output = self.summary(output)
+        output = self.activation(output)
+        output = self.last_dropout(output)
+
+        return output
+
+
 class RoFormerPredictionHeadTransform(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -1292,7 +1392,7 @@ class RoFormerForMultipleChoice(RoFormerPreTrainedModel):
         super().__init__(config)

         self.roformer = RoFormerModel(config)
-        self.sequence_summary = SequenceSummary(config)
+        self.sequence_summary = RoFormerSequenceSummary(config)
         self.classifier = nn.Linear(config.hidden_size, 1)

         # Initialize weights and apply final processing
@@ -20,7 +20,8 @@ import torch
 import torch.utils.checkpoint
 from torch import nn

-from ...modeling_utils import ModelOutput, PreTrainedModel
+from ...modeling_outputs import ModelOutput
+from ...modeling_utils import PreTrainedModel
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_univnet import UnivNetConfig

@@ -19,14 +19,14 @@ PyTorch XLM model.
 import itertools
 import math
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Callable, Dict, Optional, Tuple, Union

 import numpy as np
 import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

-from ...activations import gelu
+from ...activations import gelu, get_activation
 from ...generation import GenerationMixin
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -36,7 +36,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead
+from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
@@ -88,6 +88,425 @@ def get_masks(slen, lengths, causal, padding_mask=None):
     return mask, attn_mask


+@dataclass
+class XLMSquadHeadOutput(ModelOutput):
+    """
+    Base class for outputs of question answering models using a [`~modeling_utils.XLMSQuADHead`].
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
+            Classification loss as the sum of start token, end token (and is_impossible if provided) classification
+            losses.
+        start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
+            Log probabilities for the top config.start_n_top start token possibilities (beam-search).
+        start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
+            Indices for the top config.start_n_top start token possibilities (beam-search).
+        end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
+            Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
+            (beam-search).
+        end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
+            Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
+        cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
+            Log probabilities for the `is_impossible` label of the answers.
+
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    start_top_log_probs: Optional[torch.FloatTensor] = None
+    start_top_index: Optional[torch.LongTensor] = None
+    end_top_log_probs: Optional[torch.FloatTensor] = None
+    end_top_index: Optional[torch.LongTensor] = None
+    cls_logits: Optional[torch.FloatTensor] = None
+
+
+class XLMPoolerStartLogits(nn.Module):
+    """
+    Compute SQuAD start logits from sequence hidden states.
+
+    Args:
+        config ([`XLMConfig`]):
+            The config used by the model, will be used to grab the `hidden_size` of the model.
+    """
+
+    def __init__(self, config: XLMConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, 1)
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
+    ) -> torch.FloatTensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
+                The final hidden states of the model.
+            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
+                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
+                should be masked.
+
+        Returns:
+            `torch.FloatTensor`: The start logits for SQuAD.
+        """
+        x = self.dense(hidden_states).squeeze(-1)
+
+        if p_mask is not None:
+            if p_mask.dtype == torch.float16:
+                x = x * (1 - p_mask) - 65500 * p_mask
+            else:
+                x = x * (1 - p_mask) - 1e30 * p_mask
+
+        return x
+
+
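A tiny illustration of the `p_mask` convention used by this pooler and the ones that follow: `1.0` marks positions that may not be selected (question tokens, special symbols), and masked logits are pushed to a large negative value, with a smaller constant on the `float16` branch so it stays representable. The values below are made up for illustration:

```python
import torch

logits = torch.tensor([2.0, 0.5, -1.0, 3.0])
p_mask = torch.tensor([1.0, 0.0, 0.0, 1.0])      # first and last position cannot be an answer start
masked = logits * (1 - p_mask) - 1e30 * p_mask   # float32 branch of the code above
print(masked.softmax(dim=-1))                    # probability mass only on the unmasked positions
```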
+class XLMPoolerEndLogits(nn.Module):
+    """
+    Compute SQuAD end logits from sequence hidden states.
+
+    Args:
+        config ([`XLMConfig`]):
+            The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
+            to use.
+    """
+
+    def __init__(self, config: XLMConfig):
+        super().__init__()
+        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
+        self.activation = nn.Tanh()
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dense_1 = nn.Linear(config.hidden_size, 1)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        start_states: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        p_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
+                The final hidden states of the model.
+            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
+                The hidden states of the first tokens for the labeled span.
+            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                The position of the first token for the labeled span.
+            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
+                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
+                should be masked.
+
+        <Tip>
+
+        One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
+        `start_states`.
+
+        </Tip>
+
+        Returns:
+            `torch.FloatTensor`: The end logits for SQuAD.
+        """
+        assert start_states is not None or start_positions is not None, (
+            "One of start_states, start_positions should be not None"
+        )
+        if start_positions is not None:
+            slen, hsz = hidden_states.shape[-2:]
+            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
+            start_states = hidden_states.gather(-2, start_positions)  # shape (bsz, 1, hsz)
+            start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)
+
+        x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
+        x = self.activation(x)
+        x = self.LayerNorm(x)
+        x = self.dense_1(x).squeeze(-1)
+
+        if p_mask is not None:
+            if p_mask.dtype == torch.float16:
+                x = x * (1 - p_mask) - 65500 * p_mask
+            else:
+                x = x * (1 - p_mask) - 1e30 * p_mask
+
+        return x
+
+
+class XLMPoolerAnswerClass(nn.Module):
+    """
+    Compute SQuAD 2.0 answer class from classification and start tokens hidden states.
+
+    Args:
+        config ([`XLMConfig`]):
+            The config used by the model, will be used to grab the `hidden_size` of the model.
+    """
+
+    def __init__(self, config: XLMConfig):
+        super().__init__()
+        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
+        self.activation = nn.Tanh()
+        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        start_states: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        cls_index: Optional[torch.LongTensor] = None,
+    ) -> torch.FloatTensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
+                The final hidden states of the model.
+            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
+                The hidden states of the first tokens for the labeled span.
+            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                The position of the first token for the labeled span.
+            cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
+
+        <Tip>
+
+        One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
+        `start_states`.
+
+        </Tip>
+
+        Returns:
+            `torch.FloatTensor`: The SQuAD 2.0 answer class.
+        """
+        # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
+        hsz = hidden_states.shape[-1]
+        assert start_states is not None or start_positions is not None, (
+            "One of start_states, start_positions should be not None"
+        )
+        if start_positions is not None:
+            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
+            start_states = hidden_states.gather(-2, start_positions).squeeze(-2)  # shape (bsz, hsz)
+
+        if cls_index is not None:
+            cls_index = cls_index[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
+            cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, hsz)
+        else:
+            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)
+
+        x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
+        x = self.activation(x)
+        x = self.dense_1(x).squeeze(-1)
+
+        return x
+
+
+class XLMSQuADHead(nn.Module):
+    r"""
+    A SQuAD head inspired by XLNet.
+
+    Args:
+        config ([`XLMConfig`]):
+            The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
+            to use.
+    """
+
+    def __init__(self, config: XLMConfig):
+        super().__init__()
+        self.start_n_top = config.start_n_top
+        self.end_n_top = config.end_n_top
+
+        self.start_logits = XLMPoolerStartLogits(config)
+        self.end_logits = XLMPoolerEndLogits(config)
+        self.answer_class = XLMPoolerAnswerClass(config)
+
+    @replace_return_docstrings(output_type=XLMSquadHeadOutput, config_class=XLMConfig)
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        cls_index: Optional[torch.LongTensor] = None,
+        is_impossible: Optional[torch.LongTensor] = None,
+        p_mask: Optional[torch.FloatTensor] = None,
+        return_dict: bool = False,
+    ) -> Union[XLMSquadHeadOutput, Tuple[torch.FloatTensor]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
+                Final hidden states of the model on the sequence tokens.
+            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                Positions of the first token for the labeled span.
+            end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                Positions of the last token for the labeled span.
+            cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
+            is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                Whether the question has a possible answer in the paragraph or not.
+            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
+                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
+                should be masked.
+            return_dict (`bool`, *optional*, defaults to `False`):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+        """
+        start_logits = self.start_logits(hidden_states, p_mask=p_mask)
+
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, let's remove the dimension added by batch splitting
+            for x in (start_positions, end_positions, cls_index, is_impossible):
+                if x is not None and x.dim() > 1:
+                    x.squeeze_(-1)
+
+            # during training, compute the end logits based on the ground truth of the start position
+            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
+
+            loss_fct = CrossEntropyLoss()
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+            if cls_index is not None and is_impossible is not None:
+                # Predict answerability from the representation of CLS and START
+                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
+                loss_fct_cls = nn.BCEWithLogitsLoss()
+                cls_loss = loss_fct_cls(cls_logits, is_impossible)
+
+                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
+                total_loss += cls_loss * 0.5
+
+            return XLMSquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)
+
+        else:
+            # during inference, compute the end logits based on beam search
+            bsz, slen, hsz = hidden_states.size()
+            start_log_probs = nn.functional.softmax(start_logits, dim=-1)  # shape (bsz, slen)
+
+            start_top_log_probs, start_top_index = torch.topk(
+                start_log_probs, self.start_n_top, dim=-1
+            )  # shape (bsz, start_n_top)
+            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
+            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
+            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)
+
+            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
+                start_states
+            )  # shape (bsz, slen, start_n_top, hsz)
+            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
+            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
+            end_log_probs = nn.functional.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)
+
+            end_top_log_probs, end_top_index = torch.topk(
+                end_log_probs, self.end_n_top, dim=1
+            )  # shape (bsz, end_n_top, start_n_top)
+            end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
+            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
+
+            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
+            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
+
+            if not return_dict:
+                return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
+            else:
+                return XLMSquadHeadOutput(
+                    start_top_log_probs=start_top_log_probs,
+                    start_top_index=start_top_index,
+                    end_top_log_probs=end_top_log_probs,
+                    end_top_index=end_top_index,
+                    cls_logits=cls_logits,
+                )
+
+
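To make the two code paths of `XLMSQuADHead.forward` concrete, here is a hedged sketch of both modes on random tensors. It assumes a transformers build that already contains this change; the tiny config values are illustrative only:

```python
import torch

from transformers import XLMConfig
from transformers.models.xlm.modeling_xlm import XLMSQuADHead

config = XLMConfig(emb_dim=32, start_n_top=2, end_n_top=2)  # emb_dim is aliased to hidden_size
head = XLMSQuADHead(config)
hidden_states = torch.randn(3, 11, 32)                      # (batch, seq_len, hidden_size)

# Training mode: gold span positions are given, so a single loss is returned.
train_out = head(
    hidden_states,
    start_positions=torch.tensor([1, 4, 0]),
    end_positions=torch.tensor([2, 6, 3]),
    return_dict=True,
)
print(train_out.loss)

# Inference mode: no labels, so beam-searched top-k candidates and cls_logits are returned.
infer_out = head(hidden_states, return_dict=True)
print(infer_out.start_top_log_probs.shape)  # (3, start_n_top)
print(infer_out.end_top_log_probs.shape)    # (3, start_n_top * end_n_top)
print(infer_out.cls_logits.shape)           # (3,)
```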
+class XLMSequenceSummary(nn.Module):
+    r"""
+    Compute a single vector summary of a sequence hidden states.
+
+    Args:
+        config ([`XLMConfig`]):
+            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+            config class of your model for the default values it uses):
+
+            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+                - `"last"` -- Take the last token hidden state (like XLNet)
+                - `"first"` -- Take the first token hidden state (like Bert)
+                - `"mean"` -- Take the mean of all tokens hidden states
+                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+                - `"attn"` -- Not implemented now, use multi-head attention
+
+            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+              (otherwise to `config.hidden_size`).
+            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+              another string or `None` will add no activation.
+            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
+    """
+
+    def __init__(self, config: XLMConfig):
+        super().__init__()
+
+        self.summary_type = getattr(config, "summary_type", "last")
+        if self.summary_type == "attn":
+            # We should use a standard multi-head attention module with absolute positional embedding for that.
+            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+            raise NotImplementedError
+
+        self.summary = nn.Identity()
+        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
+            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+                num_classes = config.num_labels
+            else:
+                num_classes = config.hidden_size
+            self.summary = nn.Linear(config.hidden_size, num_classes)
+
+        activation_string = getattr(config, "summary_activation", None)
+        self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
+
+        self.first_dropout = nn.Identity()
+        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
+            self.first_dropout = nn.Dropout(config.summary_first_dropout)
+
+        self.last_dropout = nn.Identity()
+        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
+            self.last_dropout = nn.Dropout(config.summary_last_dropout)
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
+    ) -> torch.FloatTensor:
+        """
+        Compute a single vector summary of a sequence hidden states.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
+                The hidden states of the last layer.
+            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
+
+        Returns:
+            `torch.FloatTensor`: The summary of the sequence hidden states.
+        """
+        if self.summary_type == "last":
+            output = hidden_states[:, -1]
+        elif self.summary_type == "first":
+            output = hidden_states[:, 0]
+        elif self.summary_type == "mean":
+            output = hidden_states.mean(dim=1)
+        elif self.summary_type == "cls_index":
+            if cls_index is None:
+                cls_index = torch.full_like(
+                    hidden_states[..., :1, :],
+                    hidden_states.shape[-2] - 1,
+                    dtype=torch.long,
+                )
+            else:
+                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
+                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
+            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
+        elif self.summary_type == "attn":
+            raise NotImplementedError
+
+        output = self.first_dropout(output)
+        output = self.summary(output)
+        output = self.activation(output)
+        output = self.last_dropout(output)
+
+        return output
+
+
 class MultiHeadAttention(nn.Module):
     NEW_ID = itertools.count()

@@ -252,7 +671,7 @@ class XLMPreTrainedModel(PreTrainedModel):
 @dataclass
 class XLMForQuestionAnsweringOutput(ModelOutput):
     """
-    Base class for outputs of question answering models using a `SquadHead`.
+    Base class for outputs of question answering models using a `XLMSQuADHead`.

     Args:
         loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
@@ -767,7 +1186,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
         self.config = config

         self.transformer = XLMModel(config)
-        self.sequence_summary = SequenceSummary(config)
+        self.sequence_summary = XLMSequenceSummary(config)

         # Initialize weights and apply final processing
         self.post_init()
@@ -971,7 +1390,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
         super().__init__(config)

         self.transformer = XLMModel(config)
-        self.qa_outputs = SQuADHead(config)
+        self.qa_outputs = XLMSQuADHead(config)

         # Initialize weights and apply final processing
         self.post_init()
@@ -1176,7 +1595,7 @@ class XLMForMultipleChoice(XLMPreTrainedModel):
         super().__init__(config, *inputs, **kwargs)

         self.transformer = XLMModel(config)
-        self.sequence_summary = SequenceSummary(config)
+        self.sequence_summary = XLMSequenceSummary(config)
         self.logits_proj = nn.Linear(config.num_labels, 1)

         # Initialize weights and apply final processing
@@ -19,15 +19,15 @@ PyTorch XLNet model.

 import warnings
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

-from ...activations import ACT2FN
+from ...activations import ACT2FN, get_activation
 from ...generation import GenerationMixin
-from ...modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary
+from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward
 from ...utils import (
     ModelOutput,
@ -529,6 +529,281 @@ class XLNetLayer(nn.Module):
|
|||||||
return output_x
|
return output_x
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.xlm.modeling_xlm.XLMPoolerStartLogits with XLM->XLNet
|
||||||
|
class XLNetPoolerStartLogits(nn.Module):
|
||||||
|
"""
|
||||||
|
Compute SQuAD start logits from sequence hidden states.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config ([`XLNetConfig`]):
|
||||||
|
The config used by the model, will be used to grab the `hidden_size` of the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: XLNetConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, 1)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
|
||||||
|
The final hidden states of the model.
|
||||||
|
p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
|
||||||
|
Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
|
||||||
|
should be masked.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.FloatTensor`: The start logits for SQuAD.
|
||||||
|
"""
|
||||||
|
x = self.dense(hidden_states).squeeze(-1)
|
||||||
|
|
||||||
|
if p_mask is not None:
|
||||||
|
if p_mask.dtype == torch.float16:
|
||||||
|
x = x * (1 - p_mask) - 65500 * p_mask
|
||||||
|
else:
|
||||||
|
x = x * (1 - p_mask) - 1e30 * p_mask
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
+# Copied from transformers.models.xlm.modeling_xlm.XLMPoolerEndLogits with XLM->XLNet
+class XLNetPoolerEndLogits(nn.Module):
+    """
+    Compute SQuAD end logits from sequence hidden states.
+
+    Args:
+        config ([`XLNetConfig`]):
+            The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
+            to use.
+    """
+
+    def __init__(self, config: XLNetConfig):
+        super().__init__()
+        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
+        self.activation = nn.Tanh()
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dense_1 = nn.Linear(config.hidden_size, 1)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        start_states: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        p_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
+                The final hidden states of the model.
+            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
+                The hidden states of the first tokens for the labeled span.
+            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                The position of the first token for the labeled span.
+            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
+                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
+                should be masked.
+
+        <Tip>
+
+        One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
+        `start_states`.
+
+        </Tip>
+
+        Returns:
+            `torch.FloatTensor`: The end logits for SQuAD.
+        """
+        assert start_states is not None or start_positions is not None, (
+            "One of start_states, start_positions should be not None"
+        )
+        if start_positions is not None:
+            slen, hsz = hidden_states.shape[-2:]
+            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
+            start_states = hidden_states.gather(-2, start_positions)  # shape (bsz, 1, hsz)
+            start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)
+
+        x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
+        x = self.activation(x)
+        x = self.LayerNorm(x)
+        x = self.dense_1(x).squeeze(-1)
+
+        if p_mask is not None:
+            if p_mask.dtype == torch.float16:
+                x = x * (1 - p_mask) - 65500 * p_mask
+            else:
+                x = x * (1 - p_mask) - 1e30 * p_mask
+
+        return x
+
+
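A matching sketch for the end-logits head, using gold `start_positions` as in the training path (illustrative, not part of the diff; shapes and values are arbitrary):

    import torch
    from transformers import XLNetConfig
    from transformers.models.xlnet.modeling_xlnet import XLNetPoolerEndLogits

    config = XLNetConfig()
    end_head = XLNetPoolerEndLogits(config)

    hidden_states = torch.randn(2, 8, config.hidden_size)
    start_positions = torch.tensor([1, 3])  # known start token index per example

    # The head gathers the start hidden state, broadcasts it over the sequence,
    # and scores every position as a candidate end token.
    end_logits = end_head(hidden_states, start_positions=start_positions)  # shape (2, 8)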
+# Copied from transformers.models.xlm.modeling_xlm.XLMPoolerAnswerClass with XLM->XLNet
+class XLNetPoolerAnswerClass(nn.Module):
+    """
+    Compute SQuAD 2.0 answer class from classification and start tokens hidden states.
+
+    Args:
+        config ([`XLNetConfig`]):
+            The config used by the model, will be used to grab the `hidden_size` of the model.
+    """
+
+    def __init__(self, config: XLNetConfig):
+        super().__init__()
+        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
+        self.activation = nn.Tanh()
+        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        start_states: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        cls_index: Optional[torch.LongTensor] = None,
+    ) -> torch.FloatTensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
+                The final hidden states of the model.
+            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
+                The hidden states of the first tokens for the labeled span.
+            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                The position of the first token for the labeled span.
+            cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
+
+        <Tip>
+
+        One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
+        `start_states`.
+
+        </Tip>
+
+        Returns:
+            `torch.FloatTensor`: The SQuAD 2.0 answer class.
+        """
+        # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
+        hsz = hidden_states.shape[-1]
+        assert start_states is not None or start_positions is not None, (
+            "One of start_states, start_positions should be not None"
+        )
+        if start_positions is not None:
+            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
+            start_states = hidden_states.gather(-2, start_positions).squeeze(-2)  # shape (bsz, hsz)
+
+        if cls_index is not None:
+            cls_index = cls_index[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
+            cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, hsz)
+        else:
+            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)
+
+        x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
+        x = self.activation(x)
+        x = self.dense_1(x).squeeze(-1)
+
+        return x
+
+
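And a sketch for the answer-class head, which scores answerability from the start and CLS hidden states (illustrative, not part of the diff):

    import torch
    from transformers import XLNetConfig
    from transformers.models.xlnet.modeling_xlnet import XLNetPoolerAnswerClass

    config = XLNetConfig()
    answer_head = XLNetPoolerAnswerClass(config)

    hidden_states = torch.randn(2, 8, config.hidden_size)
    start_positions = torch.tensor([1, 3])
    cls_index = torch.tensor([7, 7])  # XLNet places the CLS token last, so point at the final position

    cls_logits = answer_head(hidden_states, start_positions=start_positions, cls_index=cls_index)  # shape (2,)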
+# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->XLNet
+class XLNetSequenceSummary(nn.Module):
+    r"""
+    Compute a single vector summary of a sequence hidden states.
+
+    Args:
+        config ([`XLNetConfig`]):
+            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+            config class of your model for the default values it uses):
+
+            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+                - `"last"` -- Take the last token hidden state (like XLNet)
+                - `"first"` -- Take the first token hidden state (like Bert)
+                - `"mean"` -- Take the mean of all tokens hidden states
+                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+                - `"attn"` -- Not implemented now, use multi-head attention
+
+            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+              (otherwise to `config.hidden_size`).
+            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+              another string or `None` will add no activation.
+            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
+    """
+
+    def __init__(self, config: XLNetConfig):
+        super().__init__()
+
+        self.summary_type = getattr(config, "summary_type", "last")
+        if self.summary_type == "attn":
+            # We should use a standard multi-head attention module with absolute positional embedding for that.
+            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+            raise NotImplementedError
+
+        self.summary = nn.Identity()
+        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
+            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+                num_classes = config.num_labels
+            else:
+                num_classes = config.hidden_size
+            self.summary = nn.Linear(config.hidden_size, num_classes)
+
+        activation_string = getattr(config, "summary_activation", None)
+        self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
+
+        self.first_dropout = nn.Identity()
+        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
+            self.first_dropout = nn.Dropout(config.summary_first_dropout)
+
+        self.last_dropout = nn.Identity()
+        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
+            self.last_dropout = nn.Dropout(config.summary_last_dropout)
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
+    ) -> torch.FloatTensor:
+        """
+        Compute a single vector summary of a sequence hidden states.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
+                The hidden states of the last layer.
+            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
+
+        Returns:
+            `torch.FloatTensor`: The summary of the sequence hidden states.
+        """
+        if self.summary_type == "last":
+            output = hidden_states[:, -1]
+        elif self.summary_type == "first":
+            output = hidden_states[:, 0]
+        elif self.summary_type == "mean":
+            output = hidden_states.mean(dim=1)
+        elif self.summary_type == "cls_index":
+            if cls_index is None:
+                cls_index = torch.full_like(
+                    hidden_states[..., :1, :],
+                    hidden_states.shape[-2] - 1,
+                    dtype=torch.long,
+                )
+            else:
+                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
+                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
+            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
+        elif self.summary_type == "attn":
+            raise NotImplementedError
+
+        output = self.first_dropout(output)
+        output = self.summary(output)
+        output = self.activation(output)
+        output = self.last_dropout(output)
+
+        return output
+
+
 class XLNetPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
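The sequence-summary module added in this hunk mirrors the pooling previously done by the generic `SequenceSummary`. A short sketch with the `summary_*` options spelled out explicitly, so the behaviour follows the class definition rather than any particular config default (illustrative, not part of the diff):

    import torch
    from transformers import XLNetConfig
    from transformers.models.xlnet.modeling_xlnet import XLNetSequenceSummary

    config = XLNetConfig(
        summary_type="last",           # pool by taking the last token (XLNet convention)
        summary_use_proj=True,
        summary_proj_to_labels=False,  # project back to hidden_size instead of num_labels
        summary_activation="tanh",
    )
    summary = XLNetSequenceSummary(config)

    hidden_states = torch.randn(2, 8, config.hidden_size)
    pooled = summary(hidden_states)  # takes hidden_states[:, -1]; shape (2, config.hidden_size)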
@@ -1502,7 +1777,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         self.config = config

         self.transformer = XLNetModel(config)
-        self.sequence_summary = SequenceSummary(config)
+        self.sequence_summary = XLNetSequenceSummary(config)
         self.logits_proj = nn.Linear(config.d_model, config.num_labels)

         # Initialize weights and apply final processing
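Because `XLNetSequenceSummary` is a copy of the old generic head and the `sequence_summary` attribute name is unchanged, existing checkpoints should load as before; a quick check (illustrative sketch, `xlnet-base-cased` used only as an example checkpoint name):

    from transformers import XLNetForSequenceClassification

    model = XLNetForSequenceClassification.from_pretrained("xlnet-base-cased")
    print(type(model.sequence_summary).__name__)  # expected: XLNetSequenceSummary after this change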
@@ -1696,7 +1971,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
         super().__init__(config)

         self.transformer = XLNetModel(config)
-        self.sequence_summary = SequenceSummary(config)
+        self.sequence_summary = XLNetSequenceSummary(config)
         self.logits_proj = nn.Linear(config.d_model, 1)

         # Initialize weights and apply final processing
@@ -1911,9 +2186,9 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
         self.end_n_top = config.end_n_top

         self.transformer = XLNetModel(config)
-        self.start_logits = PoolerStartLogits(config)
-        self.end_logits = PoolerEndLogits(config)
-        self.answer_class = PoolerAnswerClass(config)
+        self.start_logits = XLNetPoolerStartLogits(config)
+        self.end_logits = XLNetPoolerEndLogits(config)
+        self.answer_class = XLNetPoolerAnswerClass(config)

         # Initialize weights and apply final processing
         self.post_init()
@@ -327,17 +327,6 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
                 is not None
             ):
                 attribute_used = True
-            # `SequenceSummary` is called with `SequenceSummary(config)`
-            elif attribute in [
-                "summary_type",
-                "summary_use_proj",
-                "summary_activation",
-                "summary_last_dropout",
-                "summary_proj_to_labels",
-                "summary_first_dropout",
-            ]:
-                if "SequenceSummary" in modeling_source:
-                    attribute_used = True
             if attribute_used:
                 break
         if attribute_used: