mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-06 14:20:04 +06:00

* remove the implied defaults to :obj:`None` * fix bug in the original * replace to :obj:`True`, :obj:`False`
796 lines
31 KiB
Python
796 lines
31 KiB
Python
# coding=utf-8
|
|
# Copyright 2018 XXX Authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" PyTorch XXX model. """
|
|
|
|
####################################################
|
|
# In this template, replace all the XXX (various casings) with your model name
|
|
####################################################
|
|
|
|
|
|
import logging
|
|
import os
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import CrossEntropyLoss, MSELoss
|
|
|
|
from .configuration_xxx import XxxConfig
|
|
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
|
from .modeling_outputs import (
|
|
BaseModelOutputWithPooling,
|
|
MaskedLMOutput,
|
|
MultipleChoiceModelOutput,
|
|
QuestionAnsweringModelOutput,
|
|
SequenceClassifierOutput,
|
|
TokenClassifierOutput,
|
|
)
|
|
from .modeling_utils import PreTrainedModel
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Class names passed to ``add_code_sample_docstrings`` (``config_class=`` / ``tokenizer_class=``) to build the
# code samples in the docstrings. They must match the real class names, e.g. the ``XxxConfig`` class imported
# at the top of this file.
_CONFIG_FOR_DOC = "XxxConfig"
_TOKENIZER_FOR_DOC = "XxxTokenizer"

####################################################
# This list contains shortcut names for some of
# the pretrained weights provided with the models
####################################################
XXX_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "xxx-base-uncased",
    "xxx-large-uncased",
]
|
|
|
|
|
|
####################################################
|
|
# This is a conversion method from TF 1.0 to PyTorch
|
|
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
|
|
####################################################
|
|
def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
    """Copy the variables of a TensorFlow checkpoint into the parameters of ``model``.

    Every checkpoint variable is read first; each one is then mapped onto a PyTorch attribute by walking its
    ``/``-separated name from ``model`` downwards and overwriting the ``.data`` of the attribute reached.

    Args:
        model: the PyTorch model whose parameters receive the checkpoint values.
        config: the model configuration (accepted for interface compatibility; not read by this function).
        tf_checkpoint_path: path of the TensorFlow checkpoint to read.

    Returns:
        ``model`` itself, with the matched parameters overwritten in place.

    Raises:
        ImportError: if ``numpy`` or ``tensorflow`` cannot be imported (an explanatory error is logged first).
        AssertionError: if a checkpoint array's shape differs from the parameter it is mapped to; the two
            shapes are appended to the exception's ``args``.
    """
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v;
    # they (and the other optimizer bookkeeping below) are not required for using a pretrained model.
    skipped_scopes = ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
    # TF scope names that correspond to a differently named PyTorch attribute.
    renamed_scopes = {
        "kernel": "weight",
        "gamma": "weight",
        "output_weights": "weight",
        "output_bias": "bias",
        "beta": "bias",
        "squad": "classifier",
    }

    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))

    # Load weights from the TF model: read every variable before touching the PyTorch model.
    checkpoint = []
    for var_name, var_shape in tf.train.list_variables(tf_path):
        logger.info("Loading TF weight {} with shape {}".format(var_name, var_shape))
        checkpoint.append((var_name, tf.train.load_variable(tf_path, var_name)))

    for var_name, array in checkpoint:
        parts = var_name.split("/")
        if any(part in skipped_scopes for part in parts):
            logger.info("Skipping {}".format("/".join(parts)))
            continue

        pointer = model
        for part in parts:
            # "layer_3" -> ["layer", "3", ""]: an attribute name followed by an index into it.
            scope_names = re.split(r"_(\d+)", part) if re.fullmatch(r"[A-Za-z]+_\d+", part) else [part]
            head = scope_names[0]
            if head in renamed_scopes:
                pointer = getattr(pointer, renamed_scopes[head])
            else:
                try:
                    pointer = getattr(pointer, head)
                except AttributeError:
                    # NOTE: this only skips the current path component (e.g. a scope the PyTorch model
                    # does not have); the remaining components are still resolved from the same pointer.
                    logger.info("Skipping {}".format("/".join(parts)))
                    continue
            if len(scope_names) >= 2:
                pointer = pointer[int(scope_names[1])]

        # The last path component decides the final adjustment.
        leaf = parts[-1]
        if leaf.endswith("_embeddings"):
            pointer = getattr(pointer, "weight")
        elif leaf == "kernel":
            # TF dense kernels are stored transposed relative to ``nn.Linear.weight``.
            array = np.transpose(array)

        try:
            assert (
                pointer.shape == array.shape
            ), f"Pointer and array have mismatched shapes {pointer.shape} and {array.shape}"
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(parts))
        pointer.data = torch.from_numpy(array)
    return model
|
|
|
|
|
|
####################################################
|
|
# PyTorch Models are constructed by sub-classing
|
|
# - torch.nn.Module for the layers and
|
|
# - PreTrainedModel for the models (itself a sub-class of torch.nn.Module)
|
|
####################################################
|
|
|
|
####################################################
|
|
# Here is an example of a typical layer in a PyTorch model of the library
|
|
# The classes are usually identical to the TF 2.0 ones without the 'TF' prefix.
|
|
#
|
|
# See the conversion methods in modeling_tf_pytorch_utils.py for more details
|
|
####################################################
|
|
|
|
# Placeholder sub-layers of ``XxxLayer`` (each is instantiated there as ``Cls(config)``).
# Replace every alias with a real ``nn.Module`` subclass implementing that part of the layer.
XxxAttention = XxxIntermediate = XxxOutput = nn.Module
|
|
|
|
|
|
class XxxLayer(nn.Module):
    """A single block of the model: attention, then an intermediate projection, then an output projection.

    The three sub-modules are built from ``XxxAttention``, ``XxxIntermediate`` and ``XxxOutput``.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = XxxAttention(config)
        self.intermediate = XxxIntermediate(config)
        self.output = XxxOutput(config)

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        """Run the block on ``hidden_states``.

        Returns a tuple whose first element is the block output; any extra values produced by the attention
        sub-module (e.g. attention probabilities, when it returns them) follow unchanged.
        """
        attn_results = self.attention(hidden_states, attention_mask, head_mask)
        attended = attn_results[0]
        extras = attn_results[1:]
        # The output sub-module receives both the intermediate activations and the attention output.
        block_output = self.output(self.intermediate(attended), attended)
        return (block_output,) + extras
|
|
|
|
|
|
####################################################
|
|
# PreTrainedModel is a sub-class of torch.nn.Module
|
|
# which takes care of loading and saving pretrained weights
|
|
# and various common utilities.
|
|
#
|
|
# Here you just need to specify a few (self-explanatory)
|
|
# pointers for your model and the weights initialization
|
|
# method if it's not fully covered by PreTrainedModel's default method
|
|
####################################################
|
|
|
|
# Layer normalization used by the model; ``XxxPreTrainedModel._init_weights`` resets modules of this type.
XxxLayerNorm = torch.nn.LayerNorm

# Placeholder building blocks of ``XxxModel`` (each is instantiated as ``Cls(config)``).
# Replace every alias with a real ``nn.Module`` subclass.
XxxEmbeddings = XxxEncoder = XxxPooler = nn.Module
|
|
|
|
|
|
class XxxPreTrainedModel(PreTrainedModel):
    """Abstract base class of every XXX model.

    Plugs the XXX configuration class, the TensorFlow checkpoint loader and the name of the base-model
    attribute into :class:`~transformers.PreTrainedModel`, which offers a simple interface for downloading and
    loading pretrained models, and defines how the weights of new sub-modules are initialized.
    """

    base_model_prefix = "transformer"
    config_class = XxxConfig
    load_tf_weights = load_tf_weights_in_xxx

    def _init_weights(self, module):
        """Initialize the weights of ``module`` in place.

        * ``nn.Linear`` / ``nn.Embedding`` weights: normal distribution with mean 0 and standard deviation
          ``config.initializer_range``.
        * ``XxxLayerNorm``: bias set to 0, weight set to 1.
        * ``nn.Linear`` biases (when present): set to 0.
        """
        is_linear_or_embedding = isinstance(module, (nn.Linear, nn.Embedding))
        if is_linear_or_embedding:
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, XxxLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        has_linear_bias = isinstance(module, nn.Linear) and module.bias is not None
        if has_linear_bias:
            module.bias.data.zero_()
|
|
|
|
|
|
# Shared introduction prepended (via ``add_start_docstrings``) to the documentation of every XXX model class.
XXX_START_DOCSTRING = r"""    The XXX model was proposed in
    `XXX: Pre-training of Deep Bidirectional Transformers for Language Understanding
    <https://arxiv.org/abs/1810.04805>`__ by....

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.XxxConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
|
|
|
|
XXX_INPUTS_DOCSTRING = r"""
|
|
Inputs:
|
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
|
|
Indices of input sequence tokens in the vocabulary.
|
|
|
|
Indices can be obtained using :class:`transformers.XxxTokenizer`.
|
|
See :func:`transformers.PreTrainedTokenizer.encode` and
|
|
:func:`transformers.PreTrainedTokenizer.__call__` for details.
|
|
|
|
`What are input IDs? <../glossary.html#input-ids>`__
|
|
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`):
|
|
Mask to avoid performing attention on padding token indices.
|
|
Mask values selected in ``[0, 1]``:
|
|
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
|
|
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
|
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`):
|
|
Segment token indices to indicate first and second portions of the inputs.
|
|
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
|
corresponds to a `sentence B` token
|
|
|
|
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
|
position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`):
|
|
Indices of positions of each input sequence tokens in the position embeddings.
|
|
Selected in the range ``[0, config.max_position_embeddings - 1]``.
|
|
|
|
`What are position IDs? <../glossary.html#position-ids>`_
|
|
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
|
Mask to nullify selected heads of the self-attention modules.
|
|
Mask values selected in ``[0, 1]``:
|
|
:obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
|
|
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
|
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
|
than the model's internal embedding lookup matrix.
|
|
output_attentions (:obj:`bool`, `optional`):
|
|
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
|
output_hidden_states (:obj:`bool`, `optional`):
|
|
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
|
|
return_dict (:obj:`bool`, `optional`):
|
|
If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
|
|
plain tuple.
|
|
"""
|
|
|
|
|
|
@add_start_docstrings(
    "The bare XXX Model transformer outputting raw hidden-states without any specific head on top.",
    XXX_START_DOCSTRING,
)
class XxxModel(XxxPreTrainedModel):
    # Base model: embeddings -> encoder -> pooler.

    def __init__(self, config):
        super().__init__(config)

        self.embeddings = XxxEmbeddings(config)
        self.encoder = XxxEncoder(config)
        self.pooler = XxxPooler(config)

        self.init_weights()

    def get_input_embeddings(self):
        """Return the word-embedding module (``self.embeddings.word_embeddings``)."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        """Replace the word-embedding module with ``new_embeddings``."""
        self.embeddings.word_embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_callable(XXX_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xxx-base-uncased",
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Per-call flags fall back to the configuration when not given explicitly.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of ``input_ids`` / ``inputs_embeds`` must be given; it determines ``input_shape``.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # Defaults: attend to every position and assign every token to segment 0.
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        ##################################
        # Replace this with your model code
        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        # Forward the resolved output flags: the encoder must know what to return, and with ``return_dict`` its
        # result must expose ``.hidden_states`` / ``.attentions``, which are read below.
        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
|
|
|
|
|
@add_start_docstrings("""XXX Model with a `language modeling` head on top. """, XXX_START_DOCSTRING)
class XxxForMaskedLM(XxxPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.transformer = XxxModel(config)
        # Projects each hidden state onto vocabulary logits; sized from ``config.hidden_size`` like every other
        # head in this file.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the language-modeling output projection (``self.lm_head``)."""
        return self.lm_head

    @add_start_docstrings_to_callable(XXX_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xxx-base-uncased",
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        # The prediction head of this class is ``self.lm_head``.
        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # outputs[1] is the pooled output, which a token-level head does not return.
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
|
@add_start_docstrings(
    """XXX Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    XXX_START_DOCSTRING,
)
class XxxForSequenceClassification(XxxPreTrainedModel):
    # Pooled output -> dropout -> linear layer with ``config.num_labels`` outputs.
    # ``num_labels == 1`` trains as regression (MSE); more labels train as classification (cross-entropy).

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XxxModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(XXX_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xxx-base-uncased",
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = self.config.use_return_dict if return_dict is None else return_dict

        base_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 of the base model output is the pooled representation.
        logits = self.classifier(self.dropout(base_outputs[1]))

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # A single output unit means regression.
                loss = MSELoss()(logits.view(-1), labels.view(-1))
            else:
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=base_outputs.hidden_states,
                attentions=base_outputs.attentions,
            )

        tuple_output = (logits,) + base_outputs[2:]
        return tuple_output if loss is None else (loss,) + tuple_output
|
|
|
|
|
|
@add_start_docstrings(
    """XXX Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    XXX_START_DOCSTRING,
)
class XxxForMultipleChoice(XxxPreTrainedModel):
    # Every (example, choice) pair is encoded separately and scored by a single-output linear layer; the scores
    # are then regrouped per example so that each row holds one logit per choice.

    def __init__(self, config):
        super().__init__(config)

        self.transformer = XxxModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.init_weights()

    @add_start_docstrings_to_callable(XXX_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xxx-base-uncased",
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)
        """
        return_dict = self.config.use_return_dict if return_dict is None else return_dict
        # The choice axis is dimension 1 of whichever tensor carries the tokens.
        token_source = input_ids if input_ids is not None else inputs_embeds
        num_choices = token_source.shape[1]

        def _merge_choice_axis(tensor):
            # Collapse every leading dimension into one, keeping the last (sequence) dimension.
            if tensor is None:
                return None
            return tensor.view(-1, tensor.size(-1))

        input_ids = _merge_choice_axis(input_ids)
        attention_mask = _merge_choice_axis(attention_mask)
        token_type_ids = _merge_choice_axis(token_type_ids)
        position_ids = _merge_choice_axis(position_ids)
        if inputs_embeds is not None:
            # Embeddings keep their last two (sequence, hidden) dimensions.
            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

        base_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        choice_scores = self.classifier(self.dropout(base_outputs[1]))
        reshaped_logits = choice_scores.view(-1, num_choices)

        loss = CrossEntropyLoss()(reshaped_logits, labels) if labels is not None else None

        if return_dict:
            return MultipleChoiceModelOutput(
                loss=loss,
                logits=reshaped_logits,
                hidden_states=base_outputs.hidden_states,
                attentions=base_outputs.attentions,
            )

        tuple_output = (reshaped_logits,) + base_outputs[2:]
        return tuple_output if loss is None else (loss,) + tuple_output
|
|
|
|
|
|
@add_start_docstrings(
    """XXX Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    XXX_START_DOCSTRING,
)
class XxxForTokenClassification(XxxPreTrainedModel):
    # Per-token hidden state -> dropout -> linear layer with ``config.num_labels`` outputs.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XxxModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(XXX_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xxx-base-uncased",
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
        """
        return_dict = self.config.use_return_dict if return_dict is None else return_dict

        base_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        token_states = self.dropout(base_outputs[0])
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            flat_logits = logits.view(-1, self.num_labels)
            flat_labels = labels.view(-1)
            if attention_mask is not None:
                # Only keep active parts of the loss: positions whose mask is not 1 get the loss's
                # ignore_index, so they do not contribute.
                keep = attention_mask.view(-1) == 1
                ignored = torch.tensor(loss_fct.ignore_index).type_as(labels)
                flat_labels = torch.where(keep, flat_labels, ignored)
            loss = loss_fct(flat_logits, flat_labels)

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=base_outputs.hidden_states,
                attentions=base_outputs.attentions,
            )

        tuple_output = (logits,) + base_outputs[2:]
        return tuple_output if loss is None else (loss,) + tuple_output
|
|
|
|
|
|
@add_start_docstrings(
    """XXX Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
    XXX_START_DOCSTRING,
)
class XxxForQuestionAnswering(XxxPreTrainedModel):
    # Per-token hidden state -> linear layer with ``config.num_labels`` outputs. The outputs are split into one
    # start and one end logit per token (the split below assumes ``num_labels == 2`` — TODO confirm in config).

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XxxModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(XXX_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xxx-base-uncased",
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            # Out-of-place clamp: ``clamp_`` would overwrite the caller's label tensors (``squeeze`` returns a
            # view, so even the squeezed tensors share storage with the caller's input).
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|