mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Patch ALBERT with heads in TensorFlow
This commit is contained in:
parent
e676764241
commit
1abd53b1aa
@ -29,14 +29,14 @@ from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||||
"albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v1-tf_model.h5",
|
"albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v1-with-prefix-tf_model.h5",
|
||||||
"albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v1-tf_model.h5",
|
"albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v1-with-prefix-tf_model.h5",
|
||||||
"albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v1-tf_model.h5",
|
"albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v1-with-prefix-tf_model.h5",
|
||||||
"albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v1-tf_model.h5",
|
"albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v1-with-prefix-tf_model.h5",
|
||||||
"albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-tf_model.h5",
|
"albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-with-prefix-tf_model.h5",
|
||||||
"albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-tf_model.h5",
|
"albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-with-prefix-tf_model.h5",
|
||||||
"albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-tf_model.h5",
|
"albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-with-prefix-tf_model.h5",
|
||||||
"albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-tf_model.h5",
|
"albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-with-prefix-tf_model.h5",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -478,6 +478,115 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class TFAlbertMainLayer(tf.keras.layers.Layer):
|
||||||
|
def __init__(self, config, **kwargs):
|
||||||
|
super().__init__(config, **kwargs)
|
||||||
|
self.num_hidden_layers = config.num_hidden_layers
|
||||||
|
|
||||||
|
self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
|
||||||
|
self.encoder = TFAlbertTransformer(config, name="encoder")
|
||||||
|
self.pooler = tf.keras.layers.Dense(
|
||||||
|
config.hidden_size,
|
||||||
|
kernel_initializer=get_initializer(config.initializer_range),
|
||||||
|
activation="tanh",
|
||||||
|
name="pooler",
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embeddings
|
||||||
|
|
||||||
|
def _resize_token_embeddings(self, new_num_tokens):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _prune_heads(self, heads_to_prune):
|
||||||
|
""" Prunes heads of the model.
|
||||||
|
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||||
|
See base class PreTrainedModel
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def call(
|
||||||
|
self,
|
||||||
|
inputs,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
input_ids = inputs[0]
|
||||||
|
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||||
|
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||||
|
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||||
|
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||||
|
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||||
|
assert len(inputs) <= 6, "Too many inputs."
|
||||||
|
elif isinstance(inputs, dict):
|
||||||
|
input_ids = inputs.get("input_ids")
|
||||||
|
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||||
|
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||||
|
position_ids = inputs.get("position_ids", position_ids)
|
||||||
|
head_mask = inputs.get("head_mask", head_mask)
|
||||||
|
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||||
|
assert len(inputs) <= 6, "Too many inputs."
|
||||||
|
else:
|
||||||
|
input_ids = inputs
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = shape_list(input_ids)
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = shape_list(inputs_embeds)[:-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = tf.fill(input_shape, 1)
|
||||||
|
if token_type_ids is None:
|
||||||
|
token_type_ids = tf.fill(input_shape, 0)
|
||||||
|
|
||||||
|
# We create a 3D attention mask from a 2D tensor mask.
|
||||||
|
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||||
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||||
|
# this attention mask is more simple than the triangular masking of causal attention
|
||||||
|
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||||
|
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
|
||||||
|
|
||||||
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
|
# positions we want to attend and -10000.0 for masked positions.
|
||||||
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
|
# effectively the same as removing these entirely.
|
||||||
|
|
||||||
|
extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
|
||||||
|
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||||
|
|
||||||
|
# Prepare head mask if needed
|
||||||
|
# 1.0 in head_mask indicate we keep the head
|
||||||
|
# attention_probs has shape bsz x n_heads x N x N
|
||||||
|
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||||
|
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||||
|
if head_mask is not None:
|
||||||
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
head_mask = [None] * self.num_hidden_layers
|
||||||
|
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||||
|
|
||||||
|
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
||||||
|
encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)
|
||||||
|
|
||||||
|
sequence_output = encoder_outputs[0]
|
||||||
|
pooled_output = self.pooler(sequence_output[:, 0])
|
||||||
|
|
||||||
|
# add hidden_states and attentions if they are here
|
||||||
|
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
|
||||||
|
# sequence_output, pooled_output, (hidden_states), (attentions)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
ALBERT_START_DOCSTRING = r"""
|
ALBERT_START_DOCSTRING = r"""
|
||||||
This model is a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ sub-class.
|
This model is a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ sub-class.
|
||||||
Use it as a regular TF 2.0 Keras Model and
|
Use it as a regular TF 2.0 Keras Model and
|
||||||
@ -560,147 +669,48 @@ ALBERT_INPUTS_DOCSTRING = r"""
|
|||||||
ALBERT_START_DOCSTRING,
|
ALBERT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFAlbertModel(TFAlbertPreTrainedModel):
|
class TFAlbertModel(TFAlbertPreTrainedModel):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
self.num_hidden_layers = config.num_hidden_layers
|
self.albert = TFAlbertMainLayer(config, name="albert")
|
||||||
|
|
||||||
self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
|
|
||||||
self.encoder = TFAlbertTransformer(config, name="encoder")
|
|
||||||
self.pooler = tf.keras.layers.Dense(
|
|
||||||
config.hidden_size,
|
|
||||||
kernel_initializer=get_initializer(config.initializer_range),
|
|
||||||
activation="tanh",
|
|
||||||
name="pooler",
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
|
||||||
return self.embeddings
|
|
||||||
|
|
||||||
def _resize_token_embeddings(self, new_num_tokens):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune):
|
|
||||||
""" Prunes heads of the model.
|
|
||||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
|
||||||
See base class PreTrainedModel
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(self, inputs, **kwargs):
|
||||||
self,
|
|
||||||
inputs,
|
|
||||||
attention_mask=None,
|
|
||||||
token_type_ids=None,
|
|
||||||
position_ids=None,
|
|
||||||
head_mask=None,
|
|
||||||
inputs_embeds=None,
|
|
||||||
training=False,
|
|
||||||
):
|
|
||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
|
||||||
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||||
Sequence of hidden-states at the output of the last layer of the model.
|
Sequence of hidden-states at the output of the last layer of the model.
|
||||||
pooler_output (:obj:`tf.Tensor` of shape :obj:`(batch_size, hidden_size)`):
|
pooler_output (:obj:`tf.Tensor` of shape :obj:`(batch_size, hidden_size)`):
|
||||||
Last layer hidden-state of the first token of the sequence (classification token)
|
Last layer hidden-state of the first token of the sequence (classification token)
|
||||||
further processed by a Linear layer and a Tanh activation function. The Linear
|
further processed by a Linear layer and a Tanh activation function. The Linear
|
||||||
layer weights are trained from the next sentence prediction (classification)
|
layer weights are trained from the next sentence prediction (classification)
|
||||||
objective during Albert pretraining. This output is usually *not* a good summary
|
objective during Albert pretraining. This output is usually *not* a good summary
|
||||||
of the semantic content of the input, you're often better with averaging or pooling
|
of the semantic content of the input, you're often better with averaging or pooling
|
||||||
the sequence of hidden-states for the whole input sequence.
|
the sequence of hidden-states for the whole input sequence.
|
||||||
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
|
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
|
||||||
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
||||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
|
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||||
tuple of :obj:`tf.Tensor` (one for each layer) of shape
|
tuple of :obj:`tf.Tensor` (one for each layer) of shape
|
||||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
|
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
|
||||||
|
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from transformers import AlbertTokenizer, TFAlbertModel
|
from transformers import AlbertTokenizer, TFAlbertModel
|
||||||
|
|
||||||
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
|
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
|
||||||
model = TFAlbertModel.from_pretrained('albert-base-v2')
|
model = TFAlbertModel.from_pretrained('albert-base-v2')
|
||||||
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
||||||
outputs = model(input_ids)
|
outputs = model(input_ids)
|
||||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if isinstance(inputs, (tuple, list)):
|
outputs = self.albert(inputs, **kwargs)
|
||||||
input_ids = inputs[0]
|
|
||||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
|
||||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
|
||||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
|
||||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
|
||||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
|
||||||
assert len(inputs) <= 6, "Too many inputs."
|
|
||||||
elif isinstance(inputs, dict):
|
|
||||||
input_ids = inputs.get("input_ids")
|
|
||||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
|
||||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
|
||||||
position_ids = inputs.get("position_ids", position_ids)
|
|
||||||
head_mask = inputs.get("head_mask", head_mask)
|
|
||||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
|
||||||
assert len(inputs) <= 6, "Too many inputs."
|
|
||||||
else:
|
|
||||||
input_ids = inputs
|
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
|
||||||
elif input_ids is not None:
|
|
||||||
input_shape = shape_list(input_ids)
|
|
||||||
elif inputs_embeds is not None:
|
|
||||||
input_shape = shape_list(inputs_embeds)[:-1]
|
|
||||||
else:
|
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
if attention_mask is None:
|
|
||||||
attention_mask = tf.fill(input_shape, 1)
|
|
||||||
if token_type_ids is None:
|
|
||||||
token_type_ids = tf.fill(input_shape, 0)
|
|
||||||
|
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
|
||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
|
||||||
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
|
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
|
||||||
# positions we want to attend and -10000.0 for masked positions.
|
|
||||||
# Since we are adding it to the raw scores before the softmax, this is
|
|
||||||
# effectively the same as removing these entirely.
|
|
||||||
|
|
||||||
extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
|
|
||||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
|
||||||
|
|
||||||
# Prepare head mask if needed
|
|
||||||
# 1.0 in head_mask indicate we keep the head
|
|
||||||
# attention_probs has shape bsz x n_heads x N x N
|
|
||||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
|
||||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
|
||||||
if head_mask is not None:
|
|
||||||
raise NotImplementedError
|
|
||||||
else:
|
|
||||||
head_mask = [None] * self.num_hidden_layers
|
|
||||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
|
||||||
|
|
||||||
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
|
||||||
encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)
|
|
||||||
|
|
||||||
sequence_output = encoder_outputs[0]
|
|
||||||
pooled_output = self.pooler(sequence_output[:, 0])
|
|
||||||
|
|
||||||
# add hidden_states and attentions if they are here
|
|
||||||
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
|
|
||||||
# sequence_output, pooled_output, (hidden_states), (attentions)
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@ -709,7 +719,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
|
|||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super(TFAlbertForMaskedLM, self).__init__(config, *inputs, **kwargs)
|
super(TFAlbertForMaskedLM, self).__init__(config, *inputs, **kwargs)
|
||||||
|
|
||||||
self.albert = TFAlbertModel(config, name="albert")
|
self.albert = TFAlbertMainLayer(config, name="albert")
|
||||||
self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")
|
self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
@ -766,7 +776,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel):
|
|||||||
super(TFAlbertForSequenceClassification, self).__init__(config, *inputs, **kwargs)
|
super(TFAlbertForSequenceClassification, self).__init__(config, *inputs, **kwargs)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
|
||||||
self.albert = TFAlbertModel(config, name="albert")
|
self.albert = TFAlbertMainLayer(config, name="albert")
|
||||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||||
self.classifier = tf.keras.layers.Dense(
|
self.classifier = tf.keras.layers.Dense(
|
||||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||||
|
Loading…
Reference in New Issue
Block a user