# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 OpenAI GPT-2 model. """

from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import json
import logging
import math
import os
import sys
from io import open

import numpy as np
import tensorflow as tf

from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary, shape_list
from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings

logger = logging.getLogger(__name__)

TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-tf_model.h5",
    "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-tf_model.h5",
    "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-tf_model.h5"}


def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
    """ Load PyTorch checkpoints in a TF 2.0 model and save it using HDF5 format.
        We use HDF5 to easily do transfer learning
        (see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
    """
    try:
        import re
        import torch
        import numpy
        from tensorflow.python.keras import backend as K
    except ImportError:
        logger.error("Loading a PyTorch model in TensorFlow requires PyTorch to be installed. Please see "
                     "https://pytorch.org/ for installation instructions.")
        raise

    pt_path = os.path.abspath(pytorch_checkpoint_path)
    logger.info("Loading PyTorch weights from {}".format(pt_path))
    # Load pytorch model
    state_dict = torch.load(pt_path, map_location='cpu')

    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
    tf_inputs = tf.constant(inputs_list)
    tfo = tf_model(tf_inputs, training=False)  # build the network

    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
    weight_value_tuples = []
    for symbolic_weight in symbolic_weights:
        name = symbolic_weight.name
        name = name.replace(':0', '')
        name = name.replace('h_', 'h/')
        name = name.split('/')
        name = name[2:]

        transpose = bool(name[-1] == 'kernel')
        if name[-1] == 'kernel' or name[-1] == 'embeddings' or name[-1] == 'gamma':
            name[-1] = 'weight'
        if name[-1] == 'beta':
            name[-1] = 'bias'

        name = '.'.join(name)
        assert name in state_dict, "Weight {} not in PyTorch model".format(name)
        array = state_dict[name].numpy()

        if transpose:
            array = numpy.transpose(array)

        if len(symbolic_weight.shape) > len(array.shape):
            array = array[None, ...]
        if len(symbolic_weight.shape) < len(array.shape):
            array = np.squeeze(array)

        try:
            assert list(symbolic_weight.shape) == list(array.shape)
        except AssertionError as e:
            e.args += (symbolic_weight.shape, array.shape)
            raise e

        logger.info("Initialize TF weight {}".format(symbolic_weight.name))

        weight_value_tuples.append((symbolic_weight, array))

        state_dict.pop(name)

    K.batch_set_value(weight_value_tuples)

    tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run

    assert not state_dict, "Weights not loaded: {}".format(list(state_dict.keys()))

    return tf_model

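
# Usage sketch (editor's note, not part of the library API): the loader above expects an
# already-instantiated TF 2.0 model and a path to a PyTorch checkpoint whose state dict keys
# follow the GPT-2 naming scheme. The paths below are hypothetical placeholders.
#
#     config = GPT2Config.from_pretrained('gpt2')
#     tf_model = TFGPT2Model(config)
#     tf_model = load_gpt2_pt_weights_in_tf2(tf_model, config, '/path/to/pytorch_model.bin')
#     tf_model.save_weights('/path/to/gpt2-tf_model.h5')  # save in HDF5 format for later TF loading
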
def gelu(x):
    """Gaussian Error Linear Unit.
    This is a smoother version of the ReLU.
    Original paper: https://arxiv.org/abs/1606.08415
    Args:
        x: float Tensor to perform activation.
    Returns:
        `x` with the GELU activation applied.
    """
    cdf = 0.5 * (1.0 + tf.tanh(
        (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
    return x * cdf

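
# Quick sanity check for the tanh approximation above (editor's illustration, approximate values):
#     gelu(tf.constant([-1.0, 0.0, 1.0]))  -> roughly [-0.159, 0.0, 0.841]
# i.e. the activation is near 0 for negative inputs, exactly 0 at 0, and close to x for large positive inputs.
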
class TFAttention(tf.keras.layers.Layer):
    def __init__(self, nx, n_ctx, config, scale=False, **kwargs):
        super(TFAttention, self).__init__(**kwargs)
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.n_ctx = n_ctx
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = TFConv1D(n_state * 3, nx, name='c_attn')
        self.c_proj = TFConv1D(n_state, nx, name='c_proj')
        self.attn_dropout = tf.keras.layers.Dropout(config.attn_pdrop)
        self.resid_dropout = tf.keras.layers.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        pass

    @staticmethod
    def causal_attention_mask(nd, ns, dtype):
        """1's in the lower triangle, counting from the lower right corner.
        Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
        """
        i = tf.range(nd)[:, None]
        j = tf.range(ns)
        m = i >= j - ns + nd
        return tf.cast(m, dtype)

    def _attn(self, inputs, training=False):
        q, k, v, attention_mask, head_mask = inputs
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        if self.scale:
            dk = tf.cast(tf.shape(k)[-1], tf.float32)  # scale attention_scores
            w = w / tf.math.sqrt(dk)

        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
        w = w * b - 1e4 * (1 - b)

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = tf.nn.softmax(w, axis=-1)
        if training:
            w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [tf.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        x = tf.transpose(x, [0, 2, 1, 3])
        x_shape = shape_list(x)
        new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
        return tf.reshape(x, new_x_shape)

    def split_heads(self, x):
        x_shape = shape_list(x)
        new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
        x = tf.reshape(x, new_x_shape)
        return tf.transpose(x, (0, 2, 1, 3))  # (batch, head, seq_length, head_features)

    def call(self, inputs, training=False):
        x, layer_past, attention_mask, head_mask = inputs

        x = self.c_attn(x)
        query, key, value = tf.split(x, 3, axis=2)
        query = self.split_heads(query)
        key = self.split_heads(key)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = tf.unstack(layer_past, axis=1)
            key = tf.concat([past_key, key], axis=-2)
            value = tf.concat([past_value, value], axis=-2)
        present = tf.stack([key, value], axis=1)

        attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        if training:
            a = self.resid_dropout(a)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)

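
# Worked example for TFAttention.causal_attention_mask (editor's illustration): with nd == ns == 3
# the mask is lower-triangular, so each query position can only attend to itself and earlier keys:
#
#     causal_attention_mask(3, 3, tf.float32)
#     # [[1., 0., 0.],
#     #  [1., 1., 0.],
#     #  [1., 1., 1.]]
#
# With cached `past` keys (ns > nd, e.g. nd=1, ns=4) the single new query may attend to all four keys,
# because the triangle is counted from the lower-right corner.
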
class TFMLP(tf.keras.layers.Layer):
    def __init__(self, n_state, config, **kwargs):
        super(TFMLP, self).__init__(**kwargs)
        nx = config.n_embd
        self.c_fc = TFConv1D(n_state, nx, name='c_fc')
        self.c_proj = TFConv1D(nx, n_state, name='c_proj')
        self.act = gelu
        self.dropout = tf.keras.layers.Dropout(config.resid_pdrop)

    def call(self, x, training=False):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        if training:
            h2 = self.dropout(h2)
        return h2

class TFBlock(tf.keras.layers.Layer):
    def __init__(self, n_ctx, config, scale=False, **kwargs):
        super(TFBlock, self).__init__(**kwargs)
        nx = config.n_embd
        self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_1')
        self.attn = TFAttention(nx, n_ctx, config, scale, name='attn')
        self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_2')
        self.mlp = TFMLP(4 * nx, config, name='mlp')

    def call(self, inputs, training=False):
        x, layer_past, attention_mask, head_mask = inputs

        a = self.ln_1(x)
        output_attn = self.attn([a, layer_past, attention_mask, head_mask], training=training)
        a = output_attn[0]  # output_attn: a, present, (attentions)
        x = x + a

        m = self.ln_2(x)
        m = self.mlp(m, training=training)
        x = x + m

        outputs = [x] + output_attn[1:]
        return outputs  # x, present, (attentions)

class TFGPT2Embeddings(tf.keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, config, **kwargs):
        super(TFGPT2Embeddings, self).__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size

    def build(self, input_shape):
        """Build shared word embedding layer
        Shared weights logic adapted from
        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        self.weight = self.add_weight(
            "weight",
            shape=[self.vocab_size, self.hidden_size],
            initializer=tf.random_normal_initializer(
                mean=0., stddev=self.hidden_size**-0.5))
        super(TFGPT2Embeddings, self).build(input_shape)

    def call(self, inputs, mode="embedding"):
        """Get token embeddings of inputs.
        Args:
            inputs: an int64 tensor of token indices with shape [batch_size, length] when
                mode == "embedding", or a float32 tensor of hidden states with shape
                [..., hidden_size] when mode == "linear".
            mode: string, a valid value is one of "embedding" and "linear".
        Returns:
            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
                shape [batch_size, length, embedding_size]; (2) if mode == "linear", output
                linear tensor, float32 with shape [batch_size, length, vocab_size].
        Raises:
            ValueError: if mode is not valid.

        Shared weights logic adapted from
        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        if mode == "embedding":
            return self._embedding(inputs)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids):
        """Applies embedding based on inputs tensor."""
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):
        """Computes logits by running inputs through a linear layer.
        Args:
            inputs: A float32 tensor with shape [..., hidden_size]
        Returns:
            float32 tensor with shape [..., vocab_size].
        """
        first_dims = shape_list(inputs)[:-1]

        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.weight, transpose_b=True)

        return tf.reshape(logits, first_dims + [self.vocab_size])

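
# Shape sketch for the shared embedding above (editor's note; `batch` and `length` are arbitrary):
#     wte = TFGPT2Embeddings(config)
#     h = wte(input_ids, mode="embedding")   # (batch, length) int ids      -> (batch, length, hidden_size)
#     logits = wte(h, mode="linear")         # (batch, length, hidden_size) -> (batch, length, vocab_size)
# Reusing the same `weight` matrix in both directions is what ties the LM head to the input embeddings.
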
class TFGPT2MainLayer(tf.keras.layers.Layer):
    def __init__(self, config, *inputs, **kwargs):
        super(TFGPT2MainLayer, self).__init__(config, *inputs, **kwargs)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.num_hidden_layers = config.n_layer
        self.vocab_size = config.vocab_size
        self.n_embd = config.n_embd

        self.wte = TFGPT2Embeddings(config, name='wte')
        self.wpe = tf.keras.layers.Embedding(config.n_positions, config.n_embd, name='wpe')
        self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
        self.h = [TFBlock(config.n_ctx, config, scale=True, name='h_{}'.format(i)) for i in range(config.n_layer)]
        self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        raise NotImplementedError

    def call(self, inputs, training=False):
        if not isinstance(inputs, (dict, tuple, list)):
            input_ids = inputs
            past, attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None, None
        elif isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            past = inputs[1] if len(inputs) > 1 else None
            attention_mask = inputs[2] if len(inputs) > 2 else None
            token_type_ids = inputs[3] if len(inputs) > 3 else None
            position_ids = inputs[4] if len(inputs) > 4 else None
            head_mask = inputs[5] if len(inputs) > 5 else None
            assert len(inputs) <= 6, "Too many inputs."
        else:
            input_ids = inputs.get('input_ids')
            past = inputs.get('past', None)
            attention_mask = inputs.get('attention_mask', None)
            token_type_ids = inputs.get('token_type_ids', None)
            position_ids = inputs.get('position_ids', None)
            head_mask = inputs.get('head_mask', None)
            assert len(inputs) <= 6, "Too many inputs."

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = shape_list(past[0][0])[-2]
        if position_ids is None:
            position_ids = tf.range(past_length, shape_list(input_ids)[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]

        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # This attention mask is simpler than the triangular masking of causal attention
            # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = tf.cast(attention_mask, tf.float32)
            attention_mask = (1.0 - attention_mask) * -10000.0
        else:
            attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_hidden_layers
            # head_mask = tf.constant([0] * self.num_hidden_layers)

        input_shape = shape_list(input_ids)
        input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

        inputs_embeds = self.wte(input_ids, mode='embedding')
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
            token_type_embeds = self.wte(token_type_ids, mode='embedding')
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        if training:
            hidden_states = self.drop(hidden_states)

        output_shape = input_shape + [shape_list(hidden_states)[-1]]

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)

            outputs = block([hidden_states, layer_past, attention_mask, head_mask[i]], training=training)

            hidden_states, present = outputs[:2]
            presents = presents + (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)

        hidden_states = tf.reshape(hidden_states, output_shape)
        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states, presents)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
            all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, presents, (all hidden_states), (attentions)

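
# Worked example for the additive attention mask built in TFGPT2MainLayer.call (editor's illustration):
# a padding mask of [1, 1, 0] (attend, attend, ignore) becomes (1.0 - mask) * -10000.0 = [0., 0., -10000.],
# which is added to the raw attention scores so that padded positions get essentially zero weight after the softmax.
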
class TFGPT2PreTrainedModel(TFPreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = GPT2Config
    pretrained_model_archive_map = TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
    load_pt_weights = load_gpt2_pt_weights_in_tf2
    base_model_prefix = "transformer"

GPT2_START_DOCSTRING = r"""    OpenAI GPT-2 model was proposed in
    `Language Models are Unsupervised Multitask Learners`_
    by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
    It's a causal (unidirectional) transformer pre-trained using language modeling on a very large
    corpus of ~40 GB of text data.

    This model is a `tf.keras.Model`_ sub-class. Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matters related to general usage and behavior.

    .. _`Language Models are Unsupervised Multitask Learners`:
        https://openai.com/blog/better-language-models/

    .. _`tf.keras.Model`:
        https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model

    Important note on the model inputs:
        The inputs of the TF 2.0 models are slightly different from the PyTorch ones since
        TF 2.0 Keras doesn't accept named arguments with default values for input Tensors.
        More precisely, input Tensors are gathered in the first argument of the model call function: `model(inputs)`.
        There are three possibilities to gather and feed the inputs to the model
        (see the usage sketch following this docstring):

        - a single Tensor with input_ids only and nothing else: `model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
            `model([input_ids, past])` or `model([input_ids, past, attention_mask])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
            `model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~pytorch_transformers.GPT2Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

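
# Usage sketch for the three input formats described in GPT2_START_DOCSTRING (editor's note; the
# tensors below are placeholders and `model` stands for any of the TF GPT-2 classes defined in this file):
#
#     input_ids = tf.constant([[31, 51, 99]])
#     attention_mask = tf.constant([[1, 1, 1]])
#
#     outputs = model(input_ids)                                    # a single Tensor
#     outputs = model([input_ids, None, attention_mask])            # a list in the documented order (input_ids, past, attention_mask, ...)
#     outputs = model({'input_ids': input_ids,
#                      'attention_mask': attention_mask})           # a dict keyed by input names
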
GPT2_INPUTS_DOCSTRING = r"""    Inputs:
        **input_ids**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            GPT-2 is a model with absolute position embeddings so it's usually advised to pad the inputs on
            the right rather than the left.
            Indices can be obtained using :class:`pytorch_transformers.GPT2Tokenizer`.
            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **past**:
            list of ``Numpy array`` or ``tf.Tensor`` (one for each layer)
            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `past` output below). Can be used to speed up sequential decoding.
        **attention_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **token_type_ids**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            A parallel sequence of tokens (can be used to indicate various portions of the inputs).
            The embeddings from these tokens will be summed with the respective token embeddings.
            Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
        **position_ids**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **head_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""

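
# Example of the `attention_mask` convention above (editor's illustration): for a batch of two
# sequences padded on the right to length 4, real tokens are marked with 1 and padding with 0:
#     input_ids      = tf.constant([[31, 51, 99, 0], [15, 5, 0, 0]])
#     attention_mask = tf.constant([[1, 1, 1, 0], [1, 1, 0, 0]])
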
@add_start_docstrings("The bare GPT2 Model transformer outputing raw hidden-states without any specific head on top.",
|
|
GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
|
|
class TFGPT2Model(TFGPT2PreTrainedModel):
|
|
r"""
|
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
|
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
|
Sequence of hidden-states at the last layer of the model.
|
|
**past**:
|
|
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
|
that contains pre-computed hidden-states (key and values in the attention blocks).
|
|
Can be used (see `past` input) to speed up sequential decoding.
|
|
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
|
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
|
of shape ``(batch_size, sequence_length, hidden_size)``:
|
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
|
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
|
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
|
|
|
Examples::
|
|
|
|
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
|
model = GPT2Model.from_pretrained('gpt2')
|
|
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
|
outputs = model(input_ids)
|
|
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
|
|
|
"""
|
|
def __init__(self, config, *inputs, **kwargs):
|
|
super(TFGPT2Model, self).__init__(config, *inputs, **kwargs)
|
|
self.transformer = TFGPT2MainLayer(config, name='transformer')
|
|
|
|
def call(self, inputs, training=False):
|
|
outputs = self.transformer(inputs, training=training)
|
|
return outputs
|
|
|
|
|
|
@add_start_docstrings("""The GPT2 Model transformer with a language modeling head on top
|
|
(linear layer with weights tied to the input embeddings). """, GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
|
|
class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
|
|
r"""
|
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
|
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
|
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
|
**past**:
|
|
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
|
that contains pre-computed hidden-states (key and values in the attention blocks).
|
|
Can be used (see `past` input) to speed up sequential decoding.
|
|
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
|
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
|
of shape ``(batch_size, sequence_length, hidden_size)``:
|
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
|
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
|
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
|
|
|
Examples::
|
|
|
|
import torch
|
|
from pytorch_transformers import GPT2Tokenizer, GPT2LMHeadModel
|
|
|
|
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
|
model = GPT2LMHeadModel.from_pretrained('gpt2')
|
|
|
|
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
|
outputs = model(input_ids)
|
|
logits = outputs[:2]
|
|
|
|
"""
|
|
def __init__(self, config, *inputs, **kwargs):
|
|
super(TFGPT2LMHeadModel, self).__init__(config, *inputs, **kwargs)
|
|
self.transformer = TFGPT2MainLayer(config, name='transformer')
|
|
|
|
def call(self, inputs, training=False):
|
|
transformer_outputs = self.transformer(inputs, training=training)
|
|
hidden_states = transformer_outputs[0]
|
|
|
|
lm_logits = self.transformer.wte(hidden_states, mode="linear")
|
|
|
|
outputs = (lm_logits,) + transformer_outputs[1:]
|
|
|
|
return outputs # lm_logits, presents, (all hidden_states), (attentions)
|
|
|
|
|
|
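
# Sketch of using `past` for incremental decoding with the LM head above (editor's note; greedy
# decoding shown only to illustrate the cache, with `tokenizer` and `model` instantiated as in the
# docstring example):
#
#     context = tf.constant([tokenizer.encode("Hello, my dog is")])
#     logits, past = model(context)[:2]                  # full forward pass over the prompt
#     next_token = tf.argmax(logits[:, -1, :], axis=-1)
#     # on the next step, feed only the new token together with the cached `past`
#     logits, past = model([next_token[:, None], past])[:2]
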
@add_start_docstrings("""The GPT2 Model transformer with a language modeling and a multiple-choice classification
|
|
head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
|
|
The language modeling head has its weights tied to the input embeddings,
|
|
the classification head takes as input the input of a specified classification token index in the input sequence).
|
|
""", GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
|
|
class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
|
r"""
|
|
**mc_token_ids**: (`optional`, default to index of the last token of the input) ``torch.LongTensor`` of shape ``(batch_size, num_choices)``:
|
|
Index of the classification token in each input sequence.
|
|
Selected in the range ``[0, input_ids.size(-1) - 1[``.
|
|
|
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
|
**lm_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length, config.vocab_size)``
|
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
|
**mc_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)``
|
|
Prediction scores of the multiplechoice classification head (scores for each choice before SoftMax).
|
|
**past**:
|
|
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
|
that contains pre-computed hidden-states (key and values in the attention blocks).
|
|
Can be used (see `past` input) to speed up sequential decoding.
|
|
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
|
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
|
of shape ``(batch_size, sequence_length, hidden_size)``:
|
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
|
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
|
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
|
|
|
Examples::
|
|
|
|
import torch
|
|
from pytorch_transformers import GPT2Tokenizer, GPT2DoubleHeadsModel
|
|
|
|
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
|
model = GPT2DoubleHeadsModel.from_pretrained('gpt2')
|
|
|
|
# Add a [CLS] to the vocabulary (we should train it also!)
|
|
tokenizer.add_special_tokens({'cls_token': '[CLS]'})
|
|
model.resize_token_embeddings(len(tokenizer)) # Update the model embeddings with the new vocabulary size
|
|
print(tokenizer.cls_token_id, len(tokenizer)) # The newly token the last token of the vocabulary
|
|
|
|
choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
|
|
encoded_choices = [tokenizer.encode(s) for s in choices]
|
|
cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
|
|
|
|
input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
|
|
mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
|
|
|
|
outputs = model(input_ids, mc_token_ids=mc_token_ids)
|
|
lm_prediction_scores, mc_prediction_scores = outputs[:2]
|
|
|
|
"""
|
|
def __init__(self, config, *inputs, **kwargs):
|
|
super(TFGPT2DoubleHeadsModel, self).__init__(config, *inputs, **kwargs)
|
|
self.transformer = TFGPT2MainLayer(config, name='transformer')
|
|
self.multiple_choice_head = TFSequenceSummary(config, name='multiple_choice_head')
|
|
|
|
|
|
def call(self, inputs, training=False):
|
|
if not isinstance(inputs, (dict, tuple, list)):
|
|
input_ids = inputs
|
|
mc_token_ids, past, attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None, None
|
|
elif isinstance(inputs, (tuple, list)):
|
|
input_ids = inputs[0]
|
|
mc_token_ids = inputs[1] if len(inputs) > 1 else None
|
|
past = inputs[2] if len(inputs) > 2 else None
|
|
attention_mask = inputs[3] if len(inputs) > 3 else None
|
|
token_type_ids = inputs[4] if len(inputs) > 4 else None
|
|
position_ids = inputs[5] if len(inputs) > 5 else None
|
|
head_mask = inputs[6] if len(inputs) > 6 else None
|
|
assert len(inputs) <= 7, "Too many inputs."
|
|
else:
|
|
input_ids = inputs.get('input_ids')
|
|
mc_token_ids = inputs.get('mc_token_ids', None)
|
|
past = inputs.get('past', None)
|
|
attention_mask = inputs.get('attention_mask', None)
|
|
token_type_ids = inputs.get('token_type_ids', None)
|
|
position_ids = inputs.get('position_ids', None)
|
|
head_mask = inputs.get('head_mask', None)
|
|
assert len(inputs) <= 5, "Too many inputs."
|
|
|
|
input_shapes = shape_list(input_ids)
|
|
|
|
seq_length = input_shapes[-1]
|
|
|
|
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
|
|
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
|
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
|
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
|
|
|
flat_inputs = [flat_input_ids, past, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
|
|
|
|
transformer_outputs = self.transformer(flat_inputs, training=training)
|
|
hidden_states = transformer_outputs[0]
|
|
|
|
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
|
|
|
|
lm_logits = self.transformer.wte(hidden_states, mode="linear")
|
|
mc_logits = self.multiple_choice_head([hidden_states, mc_token_ids], training=training)
|
|
|
|
mc_logits = tf.squeeze(mc_logits, axis=-1)
|
|
|
|
outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
|
|
|
|
return outputs # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
|