Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-02 04:10:06 +06:00)
Remove deprecated classes in modeling_utils.py (#38919)
Some checks are pending
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Waiting to run
Build documentation / build (push) Waiting to run
Slow tests on important models (on Push - A10) / Get all modified files (push) Waiting to run
Slow tests on important models (on Push - A10) / Slow & FA2 tests (push) Blocked by required conditions
Self-hosted runner (push-caller) / Check if setup was changed (push) Waiting to run
Self-hosted runner (push-caller) / build-docker-containers (push) Blocked by required conditions
Self-hosted runner (push-caller) / Trigger Push CI (push) Blocked by required conditions
Secret Leaks / trufflehog (push) Waiting to run
Update Transformers metadata / build_and_package (push) Waiting to run
* remove deprecated classes
* style
This commit is contained in:
parent 797860c68c
commit 0725cd6953
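Note for users of the removed classes: each one already emitted a deprecation warning pointing at a model-specific copy (e.g. `XLMSequenceSummary`, `XLMSQuADHead`). Below is a minimal migration sketch; the replacement import path is an assumption taken from those deprecation warnings, not something this commit itself establishes.

# Hypothetical migration sketch; the replacement import path is an assumption
# based on the deprecation warnings in the removed code below.
from transformers import XLMConfig

# Before this commit:
#     from transformers.modeling_utils import SequenceSummary
#     summary = SequenceSummary(config)

# After: use the model-specific copy named in the deprecation warning.
from transformers.models.xlm.modeling_xlm import XLMSequenceSummary

config = XLMConfig(summary_use_proj=True, summary_activation="tanh")
summary = XLMSequenceSummary(config)  # same constructor signature as the removed class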
@@ -29,7 +29,6 @@ import warnings
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from functools import partial, wraps
from threading import Thread
@@ -41,7 +40,6 @@ from huggingface_hub import split_torch_state_dict_into_shards
from packaging import version
from torch import Tensor, nn
from torch.distributions import constraints
from torch.nn import CrossEntropyLoss, Identity
from torch.utils.checkpoint import checkpoint

from transformers.utils import is_torchao_available
@@ -50,7 +48,6 @@ from transformers.utils import is_torchao_available
if is_torchao_available():
    from torchao.quantization import Int4WeightOnlyConfig

from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig
@@ -98,7 +95,6 @@ from .utils import (
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    ContextManagers,
    ModelOutput,
    PushToHubMixin,
    cached_file,
    check_torch_load_is_safe,
@@ -123,7 +119,6 @@ from .utils import (
    is_torch_xla_available,
    is_torch_xpu_available,
    logging,
    replace_return_docstrings,
    strtobool,
)
from .utils.generic import GeneralInterface
@@ -5624,453 +5619,6 @@ if PreTrainedModel.push_to_hub.__doc__ is not None:
    )


class PoolerStartLogits(nn.Module):
    """
    Compute SQuAD start logits from sequence hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model, will be used to grab the `hidden_size` of the model.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, 1)
        logger.warning_once(
            "[DEPRECATION WARNING] `PoolerStartLogits` is deprecated and will be removed in v4.53. "
            "Please use model-specific class, e.g. `XLMPoolerStartLogits`."
        )

    def forward(
        self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
    ) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                The final hidden states of the model.
            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
                should be masked.

        Returns:
            `torch.FloatTensor`: The start logits for SQuAD.
        """
        x = self.dense(hidden_states).squeeze(-1)

        if p_mask is not None:
            if get_parameter_dtype(self) == torch.float16:
                x = x * (1 - p_mask) - 65500 * p_mask
            else:
                x = x * (1 - p_mask) - 1e30 * p_mask

        return x


class PoolerEndLogits(nn.Module):
    """
    Compute SQuAD end logits from sequence hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
            to use.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dense_1 = nn.Linear(config.hidden_size, 1)
        logger.warning_once(
            "[DEPRECATION WARNING] `PoolerEndLogits` is deprecated and will be removed in v4.53. "
            "Please use model-specific class, e.g. `XLMPoolerEndLogits`."
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_states: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        p_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                The final hidden states of the model.
            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                The hidden states of the first tokens for the labeled span.
            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                The position of the first token for the labeled span.
            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
                should be masked.

        <Tip>

        One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
        `start_states`.

        </Tip>

        Returns:
            `torch.FloatTensor`: The end logits for SQuAD.
        """
        assert start_states is not None or start_positions is not None, (
            "One of start_states, start_positions should be not None"
        )
        if start_positions is not None:
            slen, hsz = hidden_states.shape[-2:]
            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions)  # shape (bsz, 1, hsz)
            start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)

        x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
        x = self.activation(x)
        x = self.LayerNorm(x)
        x = self.dense_1(x).squeeze(-1)

        if p_mask is not None:
            if get_parameter_dtype(self) == torch.float16:
                x = x * (1 - p_mask) - 65500 * p_mask
            else:
                x = x * (1 - p_mask) - 1e30 * p_mask

        return x


class PoolerAnswerClass(nn.Module):
    """
    Compute SQuAD 2.0 answer class from classification and start tokens hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model, will be used to grab the `hidden_size` of the model.
    """

    def __init__(self, config):
        super().__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
        logger.warning_once(
            "[DEPRECATION WARNING] `PoolerAnswerClass` is deprecated and will be removed in v4.53. "
            "Please use model-specific class, e.g. `XLMPoolerAnswerClass`."
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_states: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        cls_index: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                The final hidden states of the model.
            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                The hidden states of the first tokens for the labeled span.
            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                The position of the first token for the labeled span.
            cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Position of the CLS token for each sentence in the batch. If `None`, takes the last token.

        <Tip>

        One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
        `start_states`.

        </Tip>

        Returns:
            `torch.FloatTensor`: The SQuAD 2.0 answer class.
        """
        # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
        hsz = hidden_states.shape[-1]
        assert start_states is not None or start_positions is not None, (
            "One of start_states, start_positions should be not None"
        )
        if start_positions is not None:
            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions).squeeze(-2)  # shape (bsz, hsz)

        if cls_index is not None:
            cls_index = cls_index[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, hsz)
        else:
            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)

        x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
        x = self.activation(x)
        x = self.dense_1(x).squeeze(-1)

        return x


@dataclass
class SquadHeadOutput(ModelOutput):
    """
    Base class for outputs of question answering models using a [`~modeling_utils.SQuADHead`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
            Classification loss as the sum of start token, end token (and is_impossible if provided) classification
            losses.
        start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Log probabilities for the top config.start_n_top start token possibilities (beam-search).
        start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Indices for the top config.start_n_top start token possibilities (beam-search).
        end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
            (beam-search).
        end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
        cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Log probabilities for the `is_impossible` label of the answers.

    """

    loss: Optional[torch.FloatTensor] = None
    start_top_log_probs: Optional[torch.FloatTensor] = None
    start_top_index: Optional[torch.LongTensor] = None
    end_top_log_probs: Optional[torch.FloatTensor] = None
    end_top_index: Optional[torch.LongTensor] = None
    cls_logits: Optional[torch.FloatTensor] = None

    def __post_init__(self):
        logger.warning_once(
            "[DEPRECATION WARNING] `SquadHeadOutput` is deprecated and will be removed in v4.53. "
            "Please use model-specific class, e.g. `XLMSquadHeadOutput`."
        )


class SQuADHead(nn.Module):
    r"""
    A SQuAD head inspired by XLNet.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
            to use.
    """

    def __init__(self, config):
        super().__init__()
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        self.start_logits = PoolerStartLogits(config)
        self.end_logits = PoolerEndLogits(config)
        self.answer_class = PoolerAnswerClass(config)

        logger.warning_once(
            "[DEPRECATION WARNING] `SQuADHead` is deprecated and will be removed in v4.53. "
            "Please use model-specific class, e.g. `XLMSQuADHead`."
        )

    @replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig)
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        cls_index: Optional[torch.LongTensor] = None,
        is_impossible: Optional[torch.LongTensor] = None,
        p_mask: Optional[torch.FloatTensor] = None,
        return_dict: bool = False,
    ) -> Union[SquadHeadOutput, tuple[torch.FloatTensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                Final hidden states of the model on the sequence tokens.
            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Positions of the first token for the labeled span.
            end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Positions of the last token for the labeled span.
            cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
            is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Whether the question has a possible answer in the paragraph or not.
            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
                should be masked.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:
        """
        start_logits = self.start_logits(hidden_states, p_mask=p_mask)

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, let's remove the dimension added by batch splitting
            for x in (start_positions, end_positions, cls_index, is_impossible):
                if x is not None and x.dim() > 1:
                    x.squeeze_(-1)

            # during training, compute the end logits based on the ground truth of the start position
            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START
                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
                loss_fct_cls = nn.BCEWithLogitsLoss()
                cls_loss = loss_fct_cls(cls_logits, is_impossible)

                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                total_loss += cls_loss * 0.5

            return SquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)

        else:
            # during inference, compute the end logits based on beam search
            bsz, slen, hsz = hidden_states.size()
            start_log_probs = nn.functional.softmax(start_logits, dim=-1)  # shape (bsz, slen)

            start_top_log_probs, start_top_index = torch.topk(
                start_log_probs, self.start_n_top, dim=-1
            )  # shape (bsz, start_n_top)
            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
                start_states
            )  # shape (bsz, slen, start_n_top, hsz)
            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
            end_log_probs = nn.functional.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)

            end_top_log_probs, end_top_index = torch.topk(
                end_log_probs, self.end_n_top, dim=1
            )  # shape (bsz, end_n_top, start_n_top)
            end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)

            if not return_dict:
                return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
            else:
                return SquadHeadOutput(
                    start_top_log_probs=start_top_log_probs,
                    start_top_index=start_top_index,
                    end_top_log_probs=end_top_log_probs,
                    end_top_index=end_top_index,
                    cls_logits=cls_logits,
                )


class SequenceSummary(nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Not implemented now, use multi-head attention

            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string or `None` will add no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.summary = Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = get_activation(activation_string) if activation_string else Identity()

        self.first_dropout = Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

        logger.warning_once(
            "[DEPRECATION WARNING] `SequenceSummary` is deprecated and will be removed in v4.53. "
            "Please use model-specific class, e.g. `XLMSequenceSummary`."
        )

    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.

        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states.
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                cls_index = torch.full_like(
                    hidden_states[..., :1, :],
                    hidden_states.shape[-2] - 1,
                    dtype=torch.long,
                )
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output


def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
    """
    Recursively unwraps a model from potential containers (as used in distributed training).
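For reference, all of the removed pooler classes share the same `p_mask` convention: 1.0 marks a position to mask, and masked logits are pushed to a large negative value (-65500 under float16, -1e30 otherwise). A small standalone illustration of that additive-masking trick, independent of the removed classes:

import torch

# Toy start logits over 4 tokens; p_mask marks the last two positions
# (e.g. special tokens) as invalid. 1.0 means "mask this token".
logits = torch.tensor([2.0, 1.0, 0.5, 3.0])
p_mask = torch.tensor([0.0, 0.0, 1.0, 1.0])

# Same additive trick used in the removed PoolerStartLogits / PoolerEndLogits
# (their float16 branch uses -65500 instead of -1e30).
masked_logits = logits * (1 - p_mask) - 1e30 * p_mask

probs = torch.softmax(masked_logits, dim=-1)
print(probs)  # masked positions end up with ~0 probability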