mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-23 14:29:01 +06:00

This is the result of: $ black --line-length 119 examples templates transformers utils hubconf.py setup.py. There are a lot of fairly long lines in the project. As a consequence, I'm picking the longest widely accepted line length, 119 characters. This is also Thomas' preference, because it allows for explicit variable names, to make the code easier to understand.
529 lines
27 KiB
Python
529 lines
27 KiB
Python
# coding=utf-8
|
|
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" TF 2.0 XXX model. """
|
|
|
|
####################################################
|
|
# In this template, replace all the XXX (various casings) with your model name
|
|
####################################################
|
|
|
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
|
|
import json
|
|
import logging
|
|
import math
|
|
import os
|
|
import sys
|
|
import copy
|
|
import itertools
|
|
from io import open
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
from .configuration_xxx import XxxConfig
|
|
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
|
|
from .file_utils import add_start_docstrings
|
|
|
|
# Module-level logger, named after this module via the standard `logging.getLogger(__name__)` pattern.
logger = logging.getLogger(__name__)
|
|
|
|
####################################################
|
|
# This dict contains shortcut names and associated URLs
|
|
# for the pretrained weights provided with the models
|
|
####################################################
|
|
# Maps each pretrained shortcut name to the URL of its TF 2.0 (.h5) weights file.
# Exposed to the loading machinery as `TFXxxPreTrainedModel.pretrained_model_archive_map`.
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "xxx-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-tf_model.h5",
    "xxx-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-tf_model.h5",
}
|
|
|
|
####################################################
|
|
# TF 2.0 models are constructed using the Keras imperative API by sub-classing
|
|
# - tf.keras.layers.Layer for the layers and
|
|
# - TFPreTrainedModel for the models (itself a sub-class of tf.keras.Model)
|
|
####################################################
|
|
|
|
####################################################
|
|
# Here is an example of a typical layer in a TF 2.0 model of the library
|
|
# The classes are usually identical to the PyTorch ones and prefixed with 'TF'.
|
|
#
|
|
# Note that the class __init__ parameters include **kwargs (sent to 'super').
|
|
# This lets us control the class scope and variable names:
|
|
# More precisely, we set the names of the class attributes (lower level layers) to
|
|
# the equivalent attribute names in the PyTorch model so we can have equivalent
|
|
# class and scope structure between PyTorch and TF 2.0 models and easily load one in the other.
|
|
#
|
|
# See the conversion methods in modeling_tf_pytorch_utils.py for more details
|
|
####################################################
|
|
class TFXxxLayer(tf.keras.layers.Layer):
    """One transformer block: self-attention, then the intermediate sub-layer, then the output sub-layer.

    Sub-layer attribute names (``attention``, ``intermediate``, and the Keras name ``output``) mirror the
    PyTorch model so weights can be converted between the two frameworks.
    """

    def __init__(self, config, **kwargs):
        super(TFXxxLayer, self).__init__(**kwargs)
        self.attention = TFXxxAttention(config, name="attention")
        self.intermediate = TFXxxIntermediate(config, name="intermediate")
        self.transformer_output = TFXxxOutput(config, name="output")

    def call(self, inputs, training=False):
        """Apply the block.

        Args:
            inputs: a ``(hidden_states, attention_mask, head_mask)`` triple.
            training: forwarded to the attention and output sub-layers.

        Returns:
            tuple: the block output, followed by whatever extra items the attention sub-layer
            returned after its first element (e.g. attention probabilities, when enabled).
        """
        hidden_states, attention_mask, head_mask = inputs

        attention_results = self.attention([hidden_states, attention_mask, head_mask], training=training)
        attended = attention_results[0]
        extras = attention_results[1:]

        projected = self.intermediate(attended)
        block_output = self.transformer_output([projected, attended], training=training)
        return (block_output,) + extras
|
|
|
|
|
|
####################################################
|
|
# The full model without a specific pretrained or finetuning head is
|
|
# provided as a tf.keras.layers.Layer usually called "TFXxxMainLayer"
|
|
####################################################
|
|
class TFXxxMainLayer(tf.keras.layers.Layer):
    """The bare XXX transformer (no task-specific head), shared by every TFXxx* model class.

    The sub-layers used in :meth:`call` (``self.embeddings``, ``self.encoder`` and ``self.pooler``) are
    template placeholders: create them in ``__init__`` when implementing the model.
    """

    def __init__(self, config, **kwargs):
        super(TFXxxMainLayer, self).__init__(**kwargs)
        # `call` reads this to build one head-mask entry per hidden layer; without it every
        # call that does not pass a head mask fails with AttributeError.
        self.num_hidden_layers = config.num_hidden_layers
        # TODO(template): create the sub-layers used in `call`, for example:
        #   self.embeddings = TFXxxEmbeddings(config, name="embeddings")
        #   self.encoder = TFXxxEncoder(config, name="encoder")
        #   self.pooler = TFXxxPooler(config, name="pooler")

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError  # Not implemented yet in the library for TF 2.0 models

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError  # Not implemented yet in the library for TF 2.0 models

    def call(
        self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False
    ):
        """Run the bare transformer.

        Args:
            inputs: either the ``input_ids`` tensor alone, a list/tuple
                ``[input_ids, attention_mask, token_type_ids, position_ids, head_mask]`` (trailing entries
                optional), or a dict keyed by those names. Values packed into a list/dict take precedence
                over the matching keyword arguments.
            attention_mask: 1 for positions to attend to, 0 for masked positions. Defaults to all ones.
            token_type_ids: segment indices. Defaults to all zeros.
            position_ids: optional position indices, forwarded to the embeddings.
            head_mask: not supported yet; must be ``None``.
            training: Keras training flag, forwarded to the encoder.

        Returns:
            tuple: ``(sequence_output, pooled_output)`` followed by the hidden states and attentions
            when the encoder returns them.

        Raises:
            AssertionError: if more than 5 inputs are packed into ``inputs``.
            NotImplementedError: if a ``head_mask`` is given.
        """
        # We allow three types of multi-inputs:
        # - traditional keyword arguments in the call method
        # - all the arguments provided as a dict in the first positional argument of call
        # - all the arguments provided as a list/tuple (ordered) in the first positional argument of call
        # The last two options are useful to use the tf.keras fit() method.

        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            assert len(inputs) <= 5, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            assert len(inputs) <= 5, "Too many inputs."
        else:
            input_ids = inputs

        if attention_mask is None:
            attention_mask = tf.fill(shape_list(input_ids), 1)
        if token_type_ids is None:
            token_type_ids = tf.fill(shape_list(input_ids), 0)

        # We create a 4D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError("head_mask is not supported yet for TF 2.0 XXX models.")
        head_mask = [None] * self.num_hidden_layers

        ##################################
        # Replace this with your model code
        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)
        sequence_output = encoder_outputs[0]
        # The task heads in this file read outputs[1] as the pooled output and outputs[2:] as the optional
        # extras, and TFXxxModel documents `pooler_output` as the second output, so it must come second.
        pooled_output = self.pooler(sequence_output)
        # add hidden_states and attentions if they are here
        outputs = (sequence_output, pooled_output) + encoder_outputs[1:]

        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
|
|
|
|
|
|
####################################################
|
|
# TFXxxPreTrainedModel is a sub-class of tf.keras.Model
|
|
# which takes care of loading and saving pretrained weights
|
|
# and various common utilities.
|
|
# Here you just need to specify a few (self-explanatory)
|
|
# pointers for your model.
|
|
####################################################
|
|
class TFXxxPreTrainedModel(TFPreTrainedModel):
    """Abstract base class for all TF 2.0 XXX models.

    Plugs the XXX-specific pieces into :class:`TFPreTrainedModel`, which handles downloading and
    loading pretrained weights:

    - ``config_class``: the configuration class used to build the model;
    - ``pretrained_model_archive_map``: shortcut name -> weights URL;
    - ``base_model_prefix``: name of the attribute holding the base model (the subclasses in
      this module store it as ``self.transformer``).
    """

    base_model_prefix = "transformer"
    config_class = XxxConfig
    pretrained_model_archive_map = TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
|
|
|
|
|
# Shared introduction prepended (via `add_start_docstrings`) to the docstring of every TF XXX model class.
XXX_START_DOCSTRING = r"""    The XXX model was proposed in
    `XXX: Pre-training of Deep Bidirectional Transformers for Language Understanding`_
    by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer
    pre-trained using a combination of masked language modeling objective and next sentence prediction
    on a large corpus comprising the Toronto Book Corpus and Wikipedia.

    This model is a `tf.keras.Model`_ sub-class. Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matter related to general usage and behavior.

    .. _`XXX: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
        https://arxiv.org/abs/1810.04805

    .. _`tf.keras.Model`:
        https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model

    Note on the model inputs:
        TF 2.0 models accept two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using the `tf.keras.Model.fit()` method, which currently requires having all the tensors in the first argument of the model call function: `model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the first positional argument:

        - a single Tensor with input_ids only and nothing else: `model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
            `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
            `model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~transformers.XxxConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.TFPreTrainedModel.from_pretrained` method to load the model weights.
"""
|
|
|
|
# Shared description of the model inputs, appended (via `add_start_docstrings`) to every TF XXX model class.
# Only inputs accepted by `TFXxxMainLayer.call` are documented here; `inputs_embeds` was removed because
# that method has no such parameter (as a keyword it raises TypeError, as a dict key it is ignored).
XXX_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            To match pre-training, XXX input sequence should be formatted with [CLS] and [SEP] tokens as follows:

            (a) For sequence pairs:

                ``tokens:         [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``

                ``token_type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1``

            (b) For single sequences:

                ``tokens:         [CLS] the dog is hairy . [SEP]``

                ``token_type_ids:   0   0   0   0  0     0   0``

            Xxx is a model with absolute position embeddings so it's usually advised to pad the inputs on
            the right rather than the left.

            Indices can be obtained using :class:`transformers.XxxTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **attention_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **token_type_ids**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            (see `XXX: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
        **position_ids**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence token in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **head_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
|
|
|
|
|
|
@add_start_docstrings(
    "The bare Xxx Model transformer outputing raw hidden-states without any specific head on top.",
    XXX_START_DOCSTRING,
    XXX_INPUTS_DOCSTRING,
)
class TFXxxModel(TFXxxPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Hidden-states produced by the last layer of the model, one vector per input token.
        **pooler_output**: ``tf.Tensor`` of shape ``(batch_size, hidden_size)``
            Last-layer hidden-state of the first token of the sequence (classification token),
            further processed by a Linear layer and a Tanh activation function. The Linear
            layer weights are trained from the next sentence prediction (classification)
            objective during Xxx pretraining. This output is usually *not* a good summary
            of the semantic content of the input; averaging or pooling the sequence of
            hidden-states over the whole input sequence often works better.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import XxxTokenizer, TFXxxModel

        tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
        model = TFXxxModel.from_pretrained('xxx-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """

    def __init__(self, config, *inputs, **kwargs):
        super(TFXxxModel, self).__init__(config, *inputs, **kwargs)
        self.transformer = TFXxxMainLayer(config, name="transformer")

    def call(self, inputs, **kwargs):
        # No head: input parsing and the full forward pass live in the main layer.
        return self.transformer(inputs, **kwargs)
|
|
|
|
|
|
@add_start_docstrings(
    """Xxx Model with a `language modeling` head on top. """,
    XXX_START_DOCSTRING,
    XXX_INPUTS_DOCSTRING,
)
class TFXxxForMaskedLM(TFXxxPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **prediction_scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Scores of the language modeling head for every vocabulary token (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import XxxTokenizer, TFXxxForMaskedLM

        tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
        model = TFXxxForMaskedLM.from_pretrained('xxx-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        prediction_scores = outputs[0]

    """

    def __init__(self, config, *inputs, **kwargs):
        super(TFXxxForMaskedLM, self).__init__(config, *inputs, **kwargs)

        self.transformer = TFXxxMainLayer(config, name="transformer")
        # The MLM head is handed the main layer's embeddings (presumably to share/tie the
        # output projection with the input embeddings -- TODO confirm in TFXxxMLMHead).
        self.mlm = TFXxxMLMHead(config, self.transformer.embeddings, name="mlm")

    def call(self, inputs, **kwargs):
        base_outputs = self.transformer(inputs, **kwargs)
        training = kwargs.get("training", False)
        scores = self.mlm(base_outputs[0], training=training)
        # Entries 0 and 1 (sequence output and, by convention here, the pooled output) are dropped;
        # hidden states / attentions are passed through when present.
        return (scores,) + base_outputs[2:]  # prediction_scores, (hidden_states), (attentions)
|
|
|
|
|
|
@add_start_docstrings(
    """Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    XXX_START_DOCSTRING,
    XXX_INPUTS_DOCSTRING,
)
class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **logits**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, config.num_labels)``
            Classification scores (or regression scores when ``config.num_labels == 1``), before SoftMax.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import XxxTokenizer, TFXxxForSequenceClassification

        tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
        model = TFXxxForSequenceClassification.from_pretrained('xxx-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        logits = outputs[0]

    """

    def __init__(self, config, *inputs, **kwargs):
        super(TFXxxForSequenceClassification, self).__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.transformer = TFXxxMainLayer(config, name="transformer")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        classifier_init = get_initializer(config.initializer_range)
        self.classifier = tf.keras.layers.Dense(config.num_labels, kernel_initializer=classifier_init, name="classifier")

    def call(self, inputs, **kwargs):
        base_outputs = self.transformer(inputs, **kwargs)
        # Index 1 of the main layer's output is used as the pooled sequence representation.
        pooled = self.dropout(base_outputs[1], training=kwargs.get("training", False))
        logits = self.classifier(pooled)
        return (logits,) + base_outputs[2:]  # logits, (hidden_states), (attentions)
|
|
|
|
|
|
@add_start_docstrings(
    """Xxx Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    XXX_START_DOCSTRING,
    XXX_INPUTS_DOCSTRING,
)
class TFXxxForTokenClassification(TFXxxPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
            Per-token classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import XxxTokenizer, TFXxxForTokenClassification

        tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
        model = TFXxxForTokenClassification.from_pretrained('xxx-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        scores = outputs[0]

    """

    def __init__(self, config, *inputs, **kwargs):
        super(TFXxxForTokenClassification, self).__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.transformer = TFXxxMainLayer(config, name="transformer")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        classifier_init = get_initializer(config.initializer_range)
        self.classifier = tf.keras.layers.Dense(config.num_labels, kernel_initializer=classifier_init, name="classifier")

    def call(self, inputs, **kwargs):
        base_outputs = self.transformer(inputs, **kwargs)
        # Classify every token from the last-layer sequence output (index 0).
        per_token = self.dropout(base_outputs[0], training=kwargs.get("training", False))
        logits = self.classifier(per_token)
        return (logits,) + base_outputs[2:]  # scores, (hidden_states), (attentions)
|
|
|
|
|
|
@add_start_docstrings(
    """Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    XXX_START_DOCSTRING,
    XXX_INPUTS_DOCSTRING,
)
class TFXxxForQuestionAnswering(TFXxxPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **start_scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length,)``
            Scores for each token being the start of the answer span (before SoftMax).
        **end_scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length,)``
            Scores for each token being the end of the answer span (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import XxxTokenizer, TFXxxForQuestionAnswering

        tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
        model = TFXxxForQuestionAnswering.from_pretrained('xxx-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        start_scores, end_scores = outputs[:2]

    """

    def __init__(self, config, *inputs, **kwargs):
        super(TFXxxForQuestionAnswering, self).__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.transformer = TFXxxMainLayer(config, name="transformer")
        span_init = get_initializer(config.initializer_range)
        self.qa_outputs = tf.keras.layers.Dense(config.num_labels, kernel_initializer=span_init, name="qa_outputs")

    def call(self, inputs, **kwargs):
        base_outputs = self.transformer(inputs, **kwargs)
        span_logits = self.qa_outputs(base_outputs[0])
        # Split the last axis into two halves and squeeze each to (batch, seq); this only works
        # when the Dense layer has 2 units (config.num_labels == 2): start logits, then end logits.
        start_logits, end_logits = (tf.squeeze(half, axis=-1) for half in tf.split(span_logits, 2, axis=-1))
        return (start_logits, end_logits) + base_outputs[2:]  # start_logits, end_logits, (hidden_states), (attentions)
|