Patch ALBERT with heads in TensorFlow

commit 1abd53b1aa
parent e676764241
Author: Lysandre
Date:   2020-02-19 18:24:11 -05:00


@@ -29,14 +29,14 @@ from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
 logger = logging.getLogger(__name__)

 TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
-    "albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v1-tf_model.h5",
-    "albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v1-tf_model.h5",
-    "albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v1-tf_model.h5",
-    "albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v1-tf_model.h5",
-    "albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-tf_model.h5",
-    "albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-tf_model.h5",
-    "albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-tf_model.h5",
-    "albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-tf_model.h5",
+    "albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v1-with-prefix-tf_model.h5",
+    "albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v1-with-prefix-tf_model.h5",
+    "albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v1-with-prefix-tf_model.h5",
+    "albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v1-with-prefix-tf_model.h5",
+    "albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-with-prefix-tf_model.h5",
+    "albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-with-prefix-tf_model.h5",
+    "albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-with-prefix-tf_model.h5",
+    "albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-with-prefix-tf_model.h5",
 }
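The archive map above is what `from_pretrained` consults to resolve a shortcut name into a checkpoint URL, so repointing every entry at the `-with-prefix-` files switches all eight shortcuts to weights saved under the new top-level `albert` scope. A minimal sketch of that lookup; the `resolve_checkpoint_url` helper is hypothetical and only the map itself comes from this file::

    # Hypothetical helper for illustration only; the real resolution logic
    # lives in the library's `from_pretrained` machinery.
    def resolve_checkpoint_url(name_or_path):
        if name_or_path in TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP:
            return TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP[name_or_path]
        return name_or_path  # assume a local path or an explicit URL

    assert resolve_checkpoint_url("albert-base-v2").endswith("albert-base-v2-with-prefix-tf_model.h5")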
@@ -478,6 +478,115 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
         return hidden_states


+class TFAlbertMainLayer(tf.keras.layers.Layer):
+    def __init__(self, config, **kwargs):
+        super().__init__(**kwargs)
+        self.num_hidden_layers = config.num_hidden_layers
+
+        self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
+        self.encoder = TFAlbertTransformer(config, name="encoder")
+        self.pooler = tf.keras.layers.Dense(
+            config.hidden_size,
+            kernel_initializer=get_initializer(config.initializer_range),
+            activation="tanh",
+            name="pooler",
+        )
+
+    def get_input_embeddings(self):
+        return self.embeddings
+
+    def _resize_token_embeddings(self, new_num_tokens):
+        raise NotImplementedError
+
+    def _prune_heads(self, heads_to_prune):
+        """ Prunes heads of the model.
+            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+            See base class PreTrainedModel
+        """
+        raise NotImplementedError
+
+    def call(
+        self,
+        inputs,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        training=False,
+    ):
+        if isinstance(inputs, (tuple, list)):
+            input_ids = inputs[0]
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+            position_ids = inputs[3] if len(inputs) > 3 else position_ids
+            head_mask = inputs[4] if len(inputs) > 4 else head_mask
+            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
+            assert len(inputs) <= 6, "Too many inputs."
+        elif isinstance(inputs, dict):
+            input_ids = inputs.get("input_ids")
+            attention_mask = inputs.get("attention_mask", attention_mask)
+            token_type_ids = inputs.get("token_type_ids", token_type_ids)
+            position_ids = inputs.get("position_ids", position_ids)
+            head_mask = inputs.get("head_mask", head_mask)
+            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
+            assert len(inputs) <= 6, "Too many inputs."
+        else:
+            input_ids = inputs
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if attention_mask is None:
+            attention_mask = tf.fill(input_shape, 1)
+        if token_type_ids is None:
+            token_type_ids = tf.fill(input_shape, 0)
+
+        # We create a 3D attention mask from a 2D tensor mask.
+        # Sizes are [batch_size, 1, 1, to_seq_length]
+        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+        # This attention mask is simpler than the triangular masking of causal attention
+        # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
+        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing them entirely.
+        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        if head_mask is not None:
+            raise NotImplementedError
+        else:
+            head_mask = [None] * self.num_hidden_layers
+            # head_mask = tf.constant([0] * self.num_hidden_layers)
+
+        embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
+        encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)
+
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(sequence_output[:, 0])
+
+        # add hidden_states and attentions if they are here
+        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
+        # sequence_output, pooled_output, (hidden_states), (attentions)
+        return outputs
 ALBERT_START_DOCSTRING = r"""
     This model is a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ sub-class.
     Use it as a regular TF 2.0 Keras Model and
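The mask handling in `TFAlbertMainLayer.call` above is the part worth seeing in isolation: a 2D padding mask becomes a 4D additive bias that broadcasts over heads and query positions. A standalone sketch with toy values, not part of the commit::

    import tensorflow as tf

    # 2D padding mask: 1 = attend, 0 = masked. One sequence of length 4
    # whose last position is padding.
    attention_mask = tf.constant([[1, 1, 1, 0]])

    # Add broadcast axes -> [batch_size, 1, 1, to_seq_length]
    extended = tf.cast(attention_mask[:, tf.newaxis, tf.newaxis, :], tf.float32)

    # 0.0 where we attend, -10000.0 where masked; added to the raw scores
    # before the softmax, which drives masked probabilities to ~0.
    extended = (1.0 - extended) * -10000.0

    scores = tf.zeros([1, 1, 4, 4])  # dummy [batch, heads, from_seq, to_seq] logits
    probs = tf.nn.softmax(scores + extended, axis=-1)
    print(probs.numpy().round(3))  # last column ~0.0, each row renormalized over the first 3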
@@ -560,147 +669,48 @@ ALBERT_INPUTS_DOCSTRING = r"""
     ALBERT_START_DOCSTRING,
 )
 class TFAlbertModel(TFAlbertPreTrainedModel):
-    def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
-        self.num_hidden_layers = config.num_hidden_layers
-
-        self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
-        self.encoder = TFAlbertTransformer(config, name="encoder")
-        self.pooler = tf.keras.layers.Dense(
-            config.hidden_size,
-            kernel_initializer=get_initializer(config.initializer_range),
-            activation="tanh",
-            name="pooler",
-        )
-
-    def get_input_embeddings(self):
-        return self.embeddings
-
-    def _resize_token_embeddings(self, new_num_tokens):
-        raise NotImplementedError
-
-    def _prune_heads(self, heads_to_prune):
-        """ Prunes heads of the model.
-            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
-            See base class PreTrainedModel
-        """
-        raise NotImplementedError
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.albert = TFAlbertMainLayer(config, name="albert")

     @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
-    def call(
-        self,
-        inputs,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        training=False,
-    ):
+    def call(self, inputs, **kwargs):
         r"""
         Returns:
             :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
             last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
                 Sequence of hidden-states at the output of the last layer of the model.
             pooler_output (:obj:`tf.Tensor` of shape :obj:`(batch_size, hidden_size)`):
                 Last layer hidden-state of the first token of the sequence (classification token)
                 further processed by a Linear layer and a Tanh activation function. The Linear
                 layer weights are trained from the next sentence prediction (classification)
                 objective during Albert pretraining. This output is usually *not* a good summary
                 of the semantic content of the input; you're often better off averaging or pooling
                 the sequence of hidden-states for the whole input sequence.
             hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
                 tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                 of shape :obj:`(batch_size, sequence_length, hidden_size)`.
                 Hidden-states of the model at the output of each layer plus the initial embedding outputs.
             attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
                 tuple of :obj:`tf.Tensor` (one for each layer) of shape
                 :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
                 Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

         Examples::

             import tensorflow as tf
             from transformers import AlbertTokenizer, TFAlbertModel

             tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
             model = TFAlbertModel.from_pretrained('albert-base-v2')
             input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
             outputs = model(input_ids)
             last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

         """
-        if isinstance(inputs, (tuple, list)):
-            input_ids = inputs[0]
-            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
-            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
-            position_ids = inputs[3] if len(inputs) > 3 else position_ids
-            head_mask = inputs[4] if len(inputs) > 4 else head_mask
-            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
-            assert len(inputs) <= 6, "Too many inputs."
-        elif isinstance(inputs, dict):
-            input_ids = inputs.get("input_ids")
-            attention_mask = inputs.get("attention_mask", attention_mask)
-            token_type_ids = inputs.get("token_type_ids", token_type_ids)
-            position_ids = inputs.get("position_ids", position_ids)
-            head_mask = inputs.get("head_mask", head_mask)
-            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
-            assert len(inputs) <= 6, "Too many inputs."
-        else:
-            input_ids = inputs
-
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
-            input_shape = shape_list(input_ids)
-        elif inputs_embeds is not None:
-            input_shape = shape_list(inputs_embeds)[:-1]
-        else:
-            raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-        if attention_mask is None:
-            attention_mask = tf.fill(input_shape, 1)
-        if token_type_ids is None:
-            token_type_ids = tf.fill(input_shape, 0)
-
-        # We create a 3D attention mask from a 2D tensor mask.
-        # Sizes are [batch_size, 1, 1, to_seq_length]
-        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
-        # this attention mask is more simple than the triangular masking of causal attention
-        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
-
-        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-        # masked positions, this operation will create a tensor which is 0.0 for
-        # positions we want to attend and -10000.0 for masked positions.
-        # Since we are adding it to the raw scores before the softmax, this is
-        # effectively the same as removing these entirely.
-        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
-
-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
-        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if head_mask is not None:
-            raise NotImplementedError
-        else:
-            head_mask = [None] * self.num_hidden_layers
-            # head_mask = tf.constant([0] * self.num_hidden_layers)
-
-        embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
-        encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)
-
-        sequence_output = encoder_outputs[0]
-        pooled_output = self.pooler(sequence_output[:, 0])
-
-        # add hidden_states and attentions if they are here
-        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
-        # sequence_output, pooled_output, (hidden_states), (attentions)
+        outputs = self.albert(inputs, **kwargs)

         return outputs
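Since `call` now simply forwards to the shared main layer, `TFAlbertModel` keeps accepting the three input formats unpacked there: a bare tensor, a list/tuple in positional order, or a dict of named tensors. For instance, extending the docstring example::

    import tensorflow as tf
    from transformers import AlbertTokenizer, TFAlbertModel

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = TFAlbertModel.from_pretrained("albert-base-v2")

    input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # batch size 1
    attention_mask = tf.ones_like(input_ids)

    outputs = model(input_ids)                    # single tensor
    outputs = model([input_ids, attention_mask])  # positional list
    outputs = model({"input_ids": input_ids, "attention_mask": attention_mask})  # named dict

    sequence_output, pooled_output = outputs[:2]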
@@ -709,7 +719,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
     def __init__(self, config, *inputs, **kwargs):
         super(TFAlbertForMaskedLM, self).__init__(config, *inputs, **kwargs)

-        self.albert = TFAlbertModel(config, name="albert")
+        self.albert = TFAlbertMainLayer(config, name="albert")
         self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")

     def get_output_embeddings(self):
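With the backbone swapped to `TFAlbertMainLayer`, the tied MLM head still reads `self.albert.embeddings` directly. A quick usage sketch in the style of the docstring examples; the output indexing is assumed to match the other TF heads::

    import tensorflow as tf
    from transformers import AlbertTokenizer, TFAlbertForMaskedLM

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = TFAlbertForMaskedLM.from_pretrained("albert-base-v2")

    input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # batch size 1
    prediction_scores = model(input_ids)[0]  # (batch_size, sequence_length, vocab_size)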
@@ -766,7 +776,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel):
         super(TFAlbertForSequenceClassification, self).__init__(config, *inputs, **kwargs)
         self.num_labels = config.num_labels

-        self.albert = TFAlbertModel(config, name="albert")
+        self.albert = TFAlbertMainLayer(config, name="albert")
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
         self.classifier = tf.keras.layers.Dense(
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
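The classification head follows the same pattern: the main layer's pooled output goes through dropout, then a dense projection to `config.num_labels` logits. Usage under the same assumptions as the sketch above::

    import tensorflow as tf
    from transformers import AlbertTokenizer, TFAlbertForSequenceClassification

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = TFAlbertForSequenceClassification.from_pretrained("albert-base-v2")

    input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # batch size 1
    logits = model(input_ids)[0]  # (batch_size, config.num_labels)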