diff --git a/src/transformers/modeling_tf_albert.py b/src/transformers/modeling_tf_albert.py index 2a1d3f1c4d8..c01e7e93311 100644 --- a/src/transformers/modeling_tf_albert.py +++ b/src/transformers/modeling_tf_albert.py @@ -29,14 +29,14 @@ from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list logger = logging.getLogger(__name__) TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = { - "albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v1-tf_model.h5", - "albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v1-tf_model.h5", - "albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v1-tf_model.h5", - "albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v1-tf_model.h5", - "albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-tf_model.h5", - "albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-tf_model.h5", - "albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-tf_model.h5", - "albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-tf_model.h5", + "albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v1-with-prefix-tf_model.h5", + "albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v1-with-prefix-tf_model.h5", + "albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v1-with-prefix-tf_model.h5", + "albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v1-with-prefix-tf_model.h5", + "albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-with-prefix-tf_model.h5", + "albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-with-prefix-tf_model.h5", + "albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-with-prefix-tf_model.h5", + "albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-with-prefix-tf_model.h5", } @@ -478,6 +478,115 @@ class TFAlbertMLMHead(tf.keras.layers.Layer): return hidden_states +class TFAlbertMainLayer(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.num_hidden_layers = config.num_hidden_layers + + self.embeddings = TFAlbertEmbeddings(config, name="embeddings") + self.encoder = TFAlbertTransformer(config, name="encoder") + self.pooler = tf.keras.layers.Dense( + config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="pooler", + ) + + def get_input_embeddings(self): + return self.embeddings + + def _resize_token_embeddings(self, new_num_tokens): + raise NotImplementedError + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + raise NotImplementedError + + def call( + self, + inputs, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + training=False, + ): + if isinstance(inputs, (tuple, list)): + input_ids = inputs[0] + attention_mask = inputs[1] if len(inputs) > 1 else attention_mask + token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids + position_ids = inputs[3] if len(inputs) > 3 else position_ids + head_mask = inputs[4] if len(inputs) > 4 else head_mask + inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds + assert len(inputs) <= 6, "Too many inputs." + elif isinstance(inputs, dict): + input_ids = inputs.get("input_ids") + attention_mask = inputs.get("attention_mask", attention_mask) + token_type_ids = inputs.get("token_type_ids", token_type_ids) + position_ids = inputs.get("position_ids", position_ids) + head_mask = inputs.get("head_mask", head_mask) + inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) + assert len(inputs) <= 6, "Too many inputs." + else: + input_ids = inputs + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if attention_mask is None: + attention_mask = tf.fill(input_shape, 1) + if token_type_ids is None: + token_type_ids = tf.fill(input_shape, 0) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + + extended_attention_mask = tf.cast(extended_attention_mask, tf.float32) + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.num_hidden_layers + # head_mask = tf.constant([0] * self.num_hidden_layers) + + embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training) + encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training) + + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output[:, 0]) + + # add hidden_states and attentions if they are here + outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] + # sequence_output, pooled_output, (hidden_states), (attentions) + return outputs + + ALBERT_START_DOCSTRING = r""" This model is a `tf.keras.Model `__ sub-class. Use it as a regular TF 2.0 Keras Model and @@ -560,147 +669,48 @@ ALBERT_INPUTS_DOCSTRING = r""" ALBERT_START_DOCSTRING, ) class TFAlbertModel(TFAlbertPreTrainedModel): - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) - self.num_hidden_layers = config.num_hidden_layers - - self.embeddings = TFAlbertEmbeddings(config, name="embeddings") - self.encoder = TFAlbertTransformer(config, name="encoder") - self.pooler = tf.keras.layers.Dense( - config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="pooler", - ) - - def get_input_embeddings(self): - return self.embeddings - - def _resize_token_embeddings(self, new_num_tokens): - raise NotImplementedError - - def _prune_heads(self, heads_to_prune): - """ Prunes heads of the model. - heads_to_prune: dict of {layer_num: list of heads to prune in this layer} - See base class PreTrainedModel - """ - raise NotImplementedError + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.albert = TFAlbertMainLayer(config, name="albert") @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) - def call( - self, - inputs, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - training=False, - ): + def call(self, inputs, **kwargs): r""" - Returns: - :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs: - last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (:obj:`tf.Tensor` of shape :obj:`(batch_size, hidden_size)`): - Last layer hidden-state of the first token of the sequence (classification token) - further processed by a Linear layer and a Tanh activation function. The Linear - layer weights are trained from the next sentence prediction (classification) - objective during Albert pretraining. This output is usually *not* a good summary - of the semantic content of the input, you're often better with averaging or pooling - the sequence of hidden-states for the whole input sequence. - hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`): - tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. + Returns: + :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs: + last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (:obj:`tf.Tensor` of shape :obj:`(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) + further processed by a Linear layer and a Tanh activation function. The Linear + layer weights are trained from the next sentence prediction (classification) + objective during Albert pretraining. This output is usually *not* a good summary + of the semantic content of the input, you're often better with averaging or pooling + the sequence of hidden-states for the whole input sequence. + hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`): + tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): - tuple of :obj:`tf.Tensor` (one for each layer) of shape - :obj:`(batch_size, num_heads, sequence_length, sequence_length)`: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): + tuple of :obj:`tf.Tensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`: - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. - Examples:: + Examples:: - import tensorflow as tf - from transformers import AlbertTokenizer, TFAlbertModel + import tensorflow as tf + from transformers import AlbertTokenizer, TFAlbertModel - tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2') - model = TFAlbertModel.from_pretrained('albert-base-v2') - input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1 - outputs = model(input_ids) - last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple + tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2') + model = TFAlbertModel.from_pretrained('albert-base-v2') + input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1 + outputs = model(input_ids) + last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple """ - if isinstance(inputs, (tuple, list)): - input_ids = inputs[0] - attention_mask = inputs[1] if len(inputs) > 1 else attention_mask - token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids - position_ids = inputs[3] if len(inputs) > 3 else position_ids - head_mask = inputs[4] if len(inputs) > 4 else head_mask - inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds - assert len(inputs) <= 6, "Too many inputs." - elif isinstance(inputs, dict): - input_ids = inputs.get("input_ids") - attention_mask = inputs.get("attention_mask", attention_mask) - token_type_ids = inputs.get("token_type_ids", token_type_ids) - position_ids = inputs.get("position_ids", position_ids) - head_mask = inputs.get("head_mask", head_mask) - inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) - assert len(inputs) <= 6, "Too many inputs." - else: - input_ids = inputs - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if attention_mask is None: - attention_mask = tf.fill(input_shape, 1) - if token_type_ids is None: - token_type_ids = tf.fill(input_shape, 0) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - - extended_attention_mask = tf.cast(extended_attention_mask, tf.float32) - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.num_hidden_layers - # head_mask = tf.constant([0] * self.num_hidden_layers) - - embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training) - encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training) - - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(sequence_output[:, 0]) - - # add hidden_states and attentions if they are here - outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] - # sequence_output, pooled_output, (hidden_states), (attentions) + outputs = self.albert(inputs, **kwargs) return outputs @@ -709,7 +719,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel): def __init__(self, config, *inputs, **kwargs): super(TFAlbertForMaskedLM, self).__init__(config, *inputs, **kwargs) - self.albert = TFAlbertModel(config, name="albert") + self.albert = TFAlbertMainLayer(config, name="albert") self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions") def get_output_embeddings(self): @@ -766,7 +776,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel): super(TFAlbertForSequenceClassification, self).__init__(config, *inputs, **kwargs) self.num_labels = config.num_labels - self.albert = TFAlbertModel(config, name="albert") + self.albert = TFAlbertMainLayer(config, name="albert") self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) self.classifier = tf.keras.layers.Dense( config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"