# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 XXX model. """

####################################################
# In this template, replace all the XXX (various casings) with your model name
####################################################

from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import math
import os
import sys
from io import open

import numpy as np
import tensorflow as tf

from .configuration_xxx import XxxConfig
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
from .file_utils import add_start_docstrings

logger = logging.getLogger(__name__)
####################################################
# This dict contains shortcut names and associated urls
# for the pretrained weights provided with the models
####################################################
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
    'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-tf_model.h5",
    'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-tf_model.h5",
}
####################################################
# TF 2.0 Models are constructed using Keras imperative API by sub-classing
# - tf.keras.layers.Layer for the layers and
# - TFPreTrainedModel for the models (itself a sub-class of tf.keras.Model)
####################################################

####################################################
# Here is an example of a typical layer in a TF 2.0 model of the library
# The classes are usually identical to the PyTorch ones and prefixed with 'TF'.
#
# Note that class __init__ parameters include **kwargs (sent to 'super').
# This lets us control the class scope and variable names:
# more precisely, we set the names of the class attributes (lower level layers) to
# the equivalent attribute names in the PyTorch model so we can have equivalent
# class and scope structure between PyTorch and TF 2.0 models and easily load one in the other.
#
# See the conversion methods in modeling_tf_pytorch_utils.py for more details
####################################################
class TFXxxLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super(TFXxxLayer, self).__init__(**kwargs)
        self.attention = TFXxxAttention(config, name='attention')
        self.intermediate = TFXxxIntermediate(config, name='intermediate')
        self.transformer_output = TFXxxOutput(config, name='output')

    def call(self, inputs, training=False):
        hidden_states, attention_mask, head_mask = inputs

        attention_outputs = self.attention([hidden_states, attention_mask, head_mask], training=training)
        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.transformer_output([intermediate_output, attention_output], training=training)
        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
        return outputs
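
####################################################
# Illustrative sketch (not part of the original template): TFXxxAttention,
# TFXxxIntermediate and TFXxxOutput used above are placeholders you still need
# to implement for your model. Assuming a BERT-like block, the output sub-layer
# could look roughly as follows (config.hidden_size, config.hidden_dropout_prob
# and config.layer_norm_eps are assumed to exist on your XxxConfig).
####################################################
class TFXxxOutput(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super(TFXxxOutput, self).__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(config.hidden_size,
                                           kernel_initializer=get_initializer(config.initializer_range),
                                           name='dense')
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, inputs, training=False):
        hidden_states, input_tensor = inputs
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, training=training)
        # residual connection followed by layer normalization
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
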
####################################################
# The full model without a specific pretrained or finetuning head is
# provided as a tf.keras.layers.Layer usually called "TFXxxMainLayer"
####################################################
class TFXxxMainLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super(TFXxxMainLayer, self).__init__(**kwargs)
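        # NOTE (added hint, not in the original template): call() below relies on
        # self.num_hidden_layers, self.embeddings and self.encoder, so your model
        # should create them here, for instance (names are illustrative):
        #   self.num_hidden_layers = config.num_hidden_layers
        #   self.embeddings = TFXxxEmbeddings(config, name='embeddings')
        #   self.encoder = TFXxxEncoder(config, name='encoder')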
    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError  # Not implemented yet in the library for TF 2.0 models

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError  # Not implemented yet in the library for TF 2.0 models

    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
        # We allow three types of multi-inputs:
        # - traditional keyword arguments in the call method
        # - all the arguments provided as a dict in the first positional argument of call
        # - all the arguments provided as a list/tuple (ordered) in the first positional argument of call
        # The last two options are useful when using the tf.keras fit() method.

        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            assert len(inputs) <= 5, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get('input_ids')
            attention_mask = inputs.get('attention_mask', attention_mask)
            token_type_ids = inputs.get('token_type_ids', token_type_ids)
            position_ids = inputs.get('position_ids', position_ids)
            head_mask = inputs.get('head_mask', head_mask)
            assert len(inputs) <= 5, "Too many inputs."
        else:
            input_ids = inputs

        if attention_mask is None:
            attention_mask = tf.fill(shape_list(input_ids), 1)
        if token_type_ids is None:
            token_type_ids = tf.fill(shape_list(input_ids), 0)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # so we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length].
        # This attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.

        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
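        # (added worked example) e.g. a padding mask [1, 1, 0] becomes
        # [0.0, 0.0, -10000.0] (broadcast over heads and query positions),
        # so padded keys receive ~zero attention weight after the softmax.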

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_hidden_layers
            # head_mask = tf.constant([0] * self.num_hidden_layers)

        ##################################
        # Replace this with your model code
        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)
        sequence_output = encoder_outputs[0]
        outputs = (sequence_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here

        return outputs  # sequence_output, (hidden_states), (attentions)

####################################################
# TFXxxPreTrainedModel is a sub-class of tf.keras.Model
# which takes care of loading and saving pretrained weights
# and various common utilities.
# Here you just need to specify a few (self-explanatory)
# pointers for your model.
####################################################
class TFXxxPreTrainedModel(TFPreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = XxxConfig
    pretrained_model_archive_map = TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "transformer"

XXX_START_DOCSTRING = r"""    The XXX model was proposed in
    `XXX: Pre-training of Deep Bidirectional Transformers for Language Understanding`_
    by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer
    pre-trained using a combination of masked language modeling objective and next sentence prediction
    on a large corpus comprising the Toronto Book Corpus and Wikipedia.

    This model is a `tf.keras.Model`_ sub-class. Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matters related to general usage and behavior.

    .. _`XXX: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
        https://arxiv.org/abs/1810.04805

    .. _`tf.keras.Model`:
        https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model

    Note on the model inputs:
        TF 2.0 models accept two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using the `tf.keras.Model.fit()` method which currently requires having all the tensors in the first argument of the model call function: `model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the first positional argument:

            - a single Tensor with input_ids only and nothing else: `model(input_ids)`
            - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
              `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
            - a dictionary with one or several input Tensors associated to the input names given in the docstring:
              `model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~transformers.XxxConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

XXX_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            To match pre-training, XXX input sequence should be formatted with [CLS] and [SEP] tokens as follows:

            (a) For sequence pairs:

                ``tokens:         [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``

                ``token_type_ids:   0   0  0    0    0     0      0   0    1  1  1  1   1   1``

            (b) For single sequences:

                ``tokens:         [CLS] the dog is hairy . [SEP]``

                ``token_type_ids:   0   0   0   0  0    0    0``

            Xxx is a model with absolute position embeddings so it's usually advised to pad the inputs on
            the right rather than the left.

            Indices can be obtained using :class:`transformers.XxxTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **attention_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **token_type_ids**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            (see `XXX: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
        **position_ids**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **head_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
        **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
"""

@add_start_docstrings("The bare Xxx Model transformer outputing raw hidden-states without any specific head on top.",
|
|
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
|
class TFXxxModel(TFXxxPreTrainedModel):
|
|
r"""
|
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
|
**last_hidden_state**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
|
Sequence of hidden-states at the output of the last layer of the model.
|
|
**pooler_output**: ``tf.Tensor`` of shape ``(batch_size, hidden_size)``
|
|
Last layer hidden-state of the first token of the sequence (classification token)
|
|
further processed by a Linear layer and a Tanh activation function. The Linear
|
|
layer weights are trained from the next sentence prediction (classification)
|
|
objective during Xxx pretraining. This output is usually *not* a good summary
|
|
of the semantic content of the input, you're often better with averaging or pooling
|
|
the sequence of hidden-states for the whole input sequence.
|
|
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
|
list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
|
|
of shape ``(batch_size, sequence_length, hidden_size)``:
|
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
|
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
|
list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
|
|
|
Examples::
|
|
|
|
import tensorflow as tf
|
|
from transformers import XxxTokenizer, TFXxxModel
|
|
|
|
tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
|
|
model = TFXxxModel.from_pretrained('xxx-base-uncased')
|
|
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
|
outputs = model(input_ids)
|
|
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
|
|
|
"""
|
|
    def __init__(self, config, *inputs, **kwargs):
        super(TFXxxModel, self).__init__(config, *inputs, **kwargs)
        self.transformer = TFXxxMainLayer(config, name='transformer')

    def call(self, inputs, **kwargs):
        outputs = self.transformer(inputs, **kwargs)
        return outputs

@add_start_docstrings("""Xxx Model with a `language modeling` head on top. """,
|
|
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
|
class TFXxxForMaskedLM(TFXxxPreTrainedModel):
|
|
r"""
|
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
|
**prediction_scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
|
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
|
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
|
list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
|
|
of shape ``(batch_size, sequence_length, hidden_size)``:
|
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
|
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
|
list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
|
|
|
Examples::
|
|
|
|
import tensorflow as tf
|
|
from transformers import XxxTokenizer, TFXxxForMaskedLM
|
|
|
|
tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
|
|
model = TFXxxForMaskedLM.from_pretrained('xxx-base-uncased')
|
|
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
|
outputs = model(input_ids)
|
|
prediction_scores = outputs[0]
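        # (illustrative addition, not part of the original example)
        # the most likely token id at each position:
        predicted_token_ids = tf.argmax(prediction_scores, axis=-1)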

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFXxxForMaskedLM, self).__init__(config, *inputs, **kwargs)

        self.transformer = TFXxxMainLayer(config, name='transformer')
        self.mlm = TFXxxMLMHead(config, self.transformer.embeddings, name='mlm')

    def call(self, inputs, **kwargs):
        outputs = self.transformer(inputs, **kwargs)

        sequence_output = outputs[0]
        prediction_scores = self.mlm(sequence_output, training=kwargs.get('training', False))

        outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here

        return outputs  # prediction_scores, (hidden_states), (attentions)

@add_start_docstrings("""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
|
the pooled output) e.g. for GLUE tasks. """,
|
|
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
|
class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
|
|
r"""
|
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
|
**logits**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, config.num_labels)``
|
|
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
|
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
|
list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
|
|
of shape ``(batch_size, sequence_length, hidden_size)``:
|
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
|
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
|
list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
|
|
|
Examples::
|
|
|
|
import tensorflow as tf
|
|
from transformers import XxxTokenizer, TFXxxForSequenceClassification
|
|
|
|
tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
|
|
model = TFXxxForSequenceClassification.from_pretrained('xxx-base-uncased')
|
|
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
|
outputs = model(input_ids)
|
|
logits = outputs[0]
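        # (illustrative addition, not part of the original example) with integer
        # labels of shape (batch_size,), a classification loss could be computed as:
        # loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        # loss = loss_fn(labels, logits)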

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFXxxForSequenceClassification, self).__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.transformer = TFXxxMainLayer(config, name='transformer')
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(config.num_labels,
                                                kernel_initializer=get_initializer(config.initializer_range),
                                                name='classifier')

    def call(self, inputs, **kwargs):
        outputs = self.transformer(inputs, **kwargs)

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False))
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        return outputs  # logits, (hidden_states), (attentions)

@add_start_docstrings("""Xxx Model with a token classification head on top (a linear layer on top of
|
|
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
|
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
|
class TFXxxForTokenClassification(TFXxxPreTrainedModel):
|
|
r"""
|
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
|
**scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
|
|
Classification scores (before SoftMax).
|
|
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
|
list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
|
|
of shape ``(batch_size, sequence_length, hidden_size)``:
|
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
|
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
|
list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
|
|
|
Examples::
|
|
|
|
import tensorflow as tf
|
|
from transformers import XxxTokenizer, TFXxxForTokenClassification
|
|
|
|
tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
|
|
model = TFXxxForTokenClassification.from_pretrained('xxx-base-uncased')
|
|
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
|
outputs = model(input_ids)
|
|
scores = outputs[0]
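        # (illustrative addition, not part of the original example)
        # per-token predicted label ids:
        predicted_label_ids = tf.argmax(scores, axis=-1)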

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFXxxForTokenClassification, self).__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.transformer = TFXxxMainLayer(config, name='transformer')
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(config.num_labels,
                                                kernel_initializer=get_initializer(config.initializer_range),
                                                name='classifier')

    def call(self, inputs, **kwargs):
        outputs = self.transformer(inputs, **kwargs)

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output, training=kwargs.get('training', False))
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        return outputs  # scores, (hidden_states), (attentions)

@add_start_docstrings("""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
|
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
|
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
|
class TFXxxForQuestionAnswering(TFXxxPreTrainedModel):
|
|
r"""
|
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
|
**start_scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length,)``
|
|
Span-start scores (before SoftMax).
|
|
**end_scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length,)``
|
|
Span-end scores (before SoftMax).
|
|
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
|
list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
|
|
of shape ``(batch_size, sequence_length, hidden_size)``:
|
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
|
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
|
list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
|
|
|
Examples::
|
|
|
|
import tensorflow as tf
|
|
from transformers import XxxTokenizer, TFXxxForQuestionAnswering
|
|
|
|
tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
|
|
model = TFXxxForQuestionAnswering.from_pretrained('xxx-base-uncased')
|
|
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
|
outputs = model(input_ids)
|
|
start_scores, end_scores = outputs[:2]
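        # (illustrative addition, not part of the original example)
        # the most likely answer span boundaries:
        start_index = tf.argmax(start_scores, axis=-1)
        end_index = tf.argmax(end_scores, axis=-1)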

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFXxxForQuestionAnswering, self).__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.transformer = TFXxxMainLayer(config, name='transformer')
        self.qa_outputs = tf.keras.layers.Dense(config.num_labels,
                                                kernel_initializer=get_initializer(config.initializer_range),
                                                name='qa_outputs')

    def call(self, inputs, **kwargs):
        outputs = self.transformer(inputs, **kwargs)

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        outputs = (start_logits, end_logits,) + outputs[2:]

        return outputs  # start_logits, end_logits, (hidden_states), (attentions)