# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 OpenAI GPT model."""

from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import json
import logging
import math
import os
import sys
from io import open

import numpy as np
import tensorflow as tf

from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
                                TFSequenceSummary, shape_list)
from .configuration_openai import OpenAIGPTConfig
from .file_utils import add_start_docstrings
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

logger = logging.getLogger(__name__)

TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-tf_model.h5"}


def load_openai_gpt_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    # build the network
    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
    tf_inputs = tf.constant(inputs_list)
    tfo = tf_model(tf_inputs, training=False)
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
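# Note: the dummy forward pass above runs the Keras model once in inference mode so that all of its
# variables are created before the PyTorch checkpoint is mapped onto them; the particular token ids
# in `inputs_list` are arbitrary.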


def gelu(x):
    """Gaussian Error Linear Unit.
    This is a smoother version of the ReLU.
    Original paper: https://arxiv.org/abs/1606.08415
    Args:
        x: float Tensor to perform activation.
    Returns:
        `x` with the GELU activation applied.
    """
    cdf = 0.5 * (1.0 + tf.tanh(
        (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
    return x * cdf
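# In closed form the approximation above is
#     gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
# which closely tracks the exact GELU, x * Phi(x), with Phi the standard normal CDF.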


def swish(x):
    return x * tf.math.sigmoid(x)


ACT_FNS = {"gelu": tf.keras.layers.Activation(gelu),
           "relu": tf.keras.activations.relu,
           "swish": tf.keras.layers.Activation(swish)}


class TFAttention(tf.keras.layers.Layer):
    def __init__(self, nx, n_ctx, config, scale=False, **kwargs):
        super(TFAttention, self).__init__(**kwargs)
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.n_ctx = n_ctx
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = TFConv1D(n_state * 3, nx, name='c_attn')
        self.c_proj = TFConv1D(n_state, nx, name='c_proj')
        self.attn_dropout = tf.keras.layers.Dropout(config.attn_pdrop)
        self.resid_dropout = tf.keras.layers.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        pass

    @staticmethod
    def causal_attention_mask(nd, ns, dtype):
        """1's in the lower triangle, counting from the lower right corner.
        Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
        """
        i = tf.range(nd)[:, None]
        j = tf.range(ns)
        m = i >= j - ns + nd
        return tf.cast(m, dtype)
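
    # For illustration, causal_attention_mask(nd=3, ns=5, dtype=tf.float32) evaluates to
    #     [[1., 1., 1., 0., 0.],
    #      [1., 1., 1., 1., 0.],
    #      [1., 1., 1., 1., 1.]]
    # i.e. destination position i may attend to source positions j with j - (ns - nd) <= i.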

    def _attn(self, inputs, training=False):
        q, k, v, attention_mask, head_mask = inputs
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        if self.scale:
            dk = tf.cast(tf.shape(k)[-1], tf.float32)  # scale attention_scores
            w = w / tf.math.sqrt(dk)

        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
        # zero out scores for future positions and push them to a large negative value
        # so they vanish after the softmax
        w = w * b - 1e4 * (1 - b)

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = tf.nn.softmax(w, axis=-1)
        w = self.attn_dropout(w, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [tf.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        # (batch, head, seq_length, head_features) -> (batch, seq_length, n_embd)
        x = tf.transpose(x, [0, 2, 1, 3])
        x_shape = shape_list(x)
        new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
        return tf.reshape(x, new_x_shape)

    def split_heads(self, x):
        # (batch, seq_length, n_embd) -> (batch, head, seq_length, head_features)
        x_shape = shape_list(x)
        new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
        x = tf.reshape(x, new_x_shape)
        return tf.transpose(x, (0, 2, 1, 3))  # (batch, head, seq_length, head_features)
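
    # Shape bookkeeping (illustrative, assuming n_embd=768 and n_head=12 as in the pretrained model):
    #     split_heads: (batch, seq_length, 768) -> (batch, 12, seq_length, 64)
    #     merge_heads: (batch, 12, seq_length, 64) -> (batch, seq_length, 768)
    # so merge_heads(split_heads(x)) restores the original shape of x.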

    def call(self, inputs, training=False):
        x, attention_mask, head_mask = inputs

        x = self.c_attn(x)
        query, key, value = tf.split(x, 3, axis=2)
        query = self.split_heads(query)
        key = self.split_heads(key)
        value = self.split_heads(value)

        attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a, training=training)

        outputs = [a] + attn_outputs[1:]
        return outputs  # a, (attentions)


class TFMLP(tf.keras.layers.Layer):
    def __init__(self, n_state, config, **kwargs):
        super(TFMLP, self).__init__(**kwargs)
        nx = config.n_embd
        self.c_fc = TFConv1D(n_state, nx, name='c_fc')
        self.c_proj = TFConv1D(nx, n_state, name='c_proj')
        self.act = gelu
        self.dropout = tf.keras.layers.Dropout(config.resid_pdrop)

    def call(self, x, training=False):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        h2 = self.dropout(h2, training=training)
        return h2


class TFBlock(tf.keras.layers.Layer):
    def __init__(self, n_ctx, config, scale=False, **kwargs):
        super(TFBlock, self).__init__(**kwargs)
        nx = config.n_embd
        self.attn = TFAttention(nx, n_ctx, config, scale, name='attn')
        self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_1')
        self.mlp = TFMLP(4 * nx, config, name='mlp')
        self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_2')

    def call(self, inputs, training=False):
        x, attention_mask, head_mask = inputs

        output_attn = self.attn([x, attention_mask, head_mask], training=training)
        a = output_attn[0]  # output_attn: a, (attentions)

        n = self.ln_1(x + a)
        m = self.mlp(n, training=training)
        h = self.ln_2(n + m)

        outputs = [h] + output_attn[1:]
        return outputs  # h, (attentions)
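
# Note: OpenAI GPT uses post-layer-norm residual blocks. In the notation of TFBlock.call above,
#     n = LayerNorm_1(x + Attention(x))
#     h = LayerNorm_2(n + MLP(n))
# i.e. layer normalization is applied after each residual addition (unlike GPT-2, which normalizes
# before each sub-layer).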


class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
    def __init__(self, config, *inputs, **kwargs):
        super(TFOpenAIGPTMainLayer, self).__init__(*inputs, **kwargs)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.num_hidden_layers = config.n_layer
        self.vocab_size = config.vocab_size
        self.n_embd = config.n_embd

        self.tokens_embed = TFSharedEmbeddings(config.vocab_size, config.n_embd, name='tokens_embed')
        self.positions_embed = tf.keras.layers.Embedding(config.n_positions, config.n_embd, name='positions_embed')
        self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
        self.h = [TFBlock(config.n_ctx,
                          config,
                          scale=True,
                          name='h_._{}'.format(i)) for i in range(config.n_layer)]

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        raise NotImplementedError
    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            assert len(inputs) <= 5, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get('input_ids')
            attention_mask = inputs.get('attention_mask', attention_mask)
            token_type_ids = inputs.get('token_type_ids', token_type_ids)
            position_ids = inputs.get('position_ids', position_ids)
            head_mask = inputs.get('head_mask', head_mask)
            assert len(inputs) <= 5, "Too many inputs."
        else:
            input_ids = inputs

        if position_ids is None:
            position_ids = tf.range(shape_list(input_ids)[-1], dtype=tf.int32)[tf.newaxis, :]

        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # so we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length].
            # This attention mask is simpler than the triangular masking of causal attention
            # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = tf.cast(attention_mask, tf.float32)
            attention_mask = (1.0 - attention_mask) * -10000.0
        else:
            attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_hidden_layers
            # head_mask = tf.constant([0] * self.num_hidden_layers)

        input_shape = shape_list(input_ids)
        input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

        inputs_embeds = self.tokens_embed(input_ids, mode='embedding')
        position_embeds = self.positions_embed(position_ids)
        if token_type_ids is not None:
            token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
            token_type_embeds = self.tokens_embed(token_type_ids, mode='embedding')
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states, training=training)

        output_shape = input_shape + [shape_list(hidden_states)[-1]]

        all_attentions = []
        all_hidden_states = ()
        for i, block in enumerate(self.h):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)

            outputs = block([hidden_states, attention_mask, head_mask[i]], training=training)
            hidden_states = outputs[0]
            if self.output_attentions:
                all_attentions.append(outputs[1])

        hidden_states = tf.reshape(hidden_states, output_shape)
        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
            all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, (all hidden_states), (attentions)


class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = OpenAIGPTConfig
    pretrained_model_archive_map = TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_pt_weights = load_openai_gpt_pt_weights_in_tf2
    base_model_prefix = "transformer"


OPENAI_GPT_START_DOCSTRING = r""" OpenAI GPT model was proposed in
    `Improving Language Understanding by Generative Pre-Training`_
    by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever.
    It's a causal (unidirectional) transformer pre-trained using language modeling on a large
    corpus with long range dependencies, the Toronto Book Corpus.

    This model is a `tf.keras.Model`_ sub-class. Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matter related to general usage and behavior.

    .. _`Improving Language Understanding by Generative Pre-Training`:
        https://openai.com/blog/language-unsupervised/

    .. _`tf.keras.Model`:
        https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model

    Note on the model inputs:
        TF 2.0 models accept two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using the `tf.keras.Model.fit()` method, which currently requires having all the tensors in the first argument of the model call function: `model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the first positional argument:

        - a single Tensor with input_ids only and nothing else: `model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
            `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
            `model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`
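
        For example, the following calls are equivalent ways of passing an input ids tensor together
        with an attention mask (illustrative only; ``input_ids`` and ``attention_mask`` are assumed to
        be already-built ``tf.Tensor`` objects)::

            outputs = model(input_ids, attention_mask=attention_mask)
            outputs = model([input_ids, attention_mask])
            outputs = model({'input_ids': input_ids, 'attention_mask': attention_mask})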

    Parameters:
        config (:class:`~pytorch_transformers.OpenAIGPTConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs:
        **input_ids**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            GPT is a model with absolute position embeddings so it's usually advised to pad the inputs on
            the right rather than the left.
            Indices can be obtained using :class:`pytorch_transformers.OpenAIGPTTokenizer`.
            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **attention_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens
            (an illustrative way to build such a mask is sketched just below this docstring).
        **token_type_ids**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            A parallel sequence of tokens (can be used to indicate various portions of the inputs).
            The embeddings from these tokens will be summed with the respective token embeddings.
            Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
        **position_ids**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **head_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""


@add_start_docstrings("The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",
                      OPENAI_GPT_START_DOCSTRING, OPENAI_GPT_INPUTS_DOCSTRING)
class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the last layer of the model.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import OpenAIGPTTokenizer, TFOpenAIGPTModel

        tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
        model = TFOpenAIGPTModel.from_pretrained('openai-gpt')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFOpenAIGPTModel, self).__init__(config, *inputs, **kwargs)
        self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')

    def call(self, inputs, **kwargs):
        outputs = self.transformer(inputs, **kwargs)
        return outputs


@add_start_docstrings("""OpenAI GPT Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """, OPENAI_GPT_START_DOCSTRING, OPENAI_GPT_INPUTS_DOCSTRING)
class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **prediction_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import OpenAIGPTTokenizer, TFOpenAIGPTLMHeadModel

        tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
        model = TFOpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        logits = outputs[0]

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFOpenAIGPTLMHeadModel, self).__init__(config, *inputs, **kwargs)
        self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')

    def call(self, inputs, **kwargs):
        transformer_outputs = self.transformer(inputs, **kwargs)
        hidden_states = transformer_outputs[0]

        # Weight tying: reuse the token embedding matrix as the output projection
        # (mode="linear" computes hidden_states @ embedding_matrix^T).
        lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")

        outputs = (lm_logits,) + transformer_outputs[1:]

        return outputs  # lm_logits, (all hidden_states), (attentions)


@add_start_docstrings("""OpenAI GPT Model transformer with a language modeling and a multiple-choice classification
head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
The language modeling head has its weights tied to the input embeddings,
the classification head takes as input the hidden state of a specified classification token index in the input sequence.
""", OPENAI_GPT_START_DOCSTRING, OPENAI_GPT_INPUTS_DOCSTRING)
class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
    r"""
        **mc_token_ids**: (`optional`, defaults to the index of the last token of the input) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, num_choices)``:
            Index of the classification token in each input sequence.
            Selected in the range ``[0, sequence_length - 1]``.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **lm_prediction_scores**: ``tf.Tensor`` of shape ``(batch_size, num_choices, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **mc_prediction_scores**: ``tf.Tensor`` of shape ``(batch_size, num_choices)``
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import OpenAIGPTTokenizer, TFOpenAIGPTDoubleHeadsModel

        tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
        model = TFOpenAIGPTDoubleHeadsModel.from_pretrained('openai-gpt')

        # Add a [CLS] token to the vocabulary (we should train it also!)
        # This option is currently not implemented in TF 2.0
        raise NotImplementedError
        tokenizer.add_special_tokens({'cls_token': '[CLS]'})
        model.resize_token_embeddings(len(tokenizer))  # Update the model embeddings with the new vocabulary size
        print(tokenizer.cls_token_id, len(tokenizer))  # The newly added token is the last token of the vocabulary

        choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        input_ids = tf.constant([tokenizer.encode(s) for s in choices])[None, :]  # Batch size 1, 2 choices
        mc_token_ids = tf.constant([input_ids.shape[-1] - 1, input_ids.shape[-1] - 1])[None, :]  # Batch size 1; index of the [CLS] token
        outputs = model(input_ids, mc_token_ids=mc_token_ids)
        lm_prediction_scores, mc_prediction_scores = outputs[:2]

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFOpenAIGPTDoubleHeadsModel, self).__init__(config, *inputs, **kwargs)
        self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
        self.multiple_choice_head = TFSequenceSummary(config, name='multiple_choice_head')

    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            mc_token_ids = inputs[5] if len(inputs) > 5 else mc_token_ids
            assert len(inputs) <= 6, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get('input_ids')
            attention_mask = inputs.get('attention_mask', attention_mask)
            token_type_ids = inputs.get('token_type_ids', token_type_ids)
            position_ids = inputs.get('position_ids', position_ids)
            head_mask = inputs.get('head_mask', head_mask)
            mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
            assert len(inputs) <= 6, "Too many inputs."
        else:
            input_ids = inputs

        input_shapes = shape_list(input_ids)

        seq_length = input_shapes[-1]

        # Flatten the choices dimension so the transformer sees a batch of
        # (batch_size * num_choices) sequences of length seq_length.
        flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None

        flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]

        transformer_outputs = self.transformer(flat_inputs, training=training)
        hidden_states = transformer_outputs[0]

        # Restore the (batch_size, num_choices, seq_length, hidden_size) shape.
        hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])

        lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
        mc_logits = self.multiple_choice_head([hidden_states, mc_token_ids], training=training)

        mc_logits = tf.squeeze(mc_logits, axis=-1)

        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]

        return outputs  # lm logits, mc logits, (all hidden_states), (attentions)