TF: CTRL with native embedding layers (#23456)

Joao Gante 2023-06-14 14:39:02 +01:00 committed by GitHub
parent eac8dede83
commit 4626df5077
2 changed files with 70 additions and 55 deletions

src/transformers/models/ctrl/modeling_tf_ctrl.py

@@ -15,10 +15,8 @@
# limitations under the License.
""" TF 2.0 CTRL model."""
from __future__ import annotations
import warnings
from typing import Optional, Tuple, Union
import numpy as np
@@ -30,7 +28,6 @@ from ...modeling_tf_utils import (
TFModelInputType,
TFPreTrainedModel,
TFSequenceClassificationLoss,
TFSharedEmbeddings,
get_initializer,
keras_serializable,
unpack_inputs,
@@ -224,8 +221,11 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size)
self.w = TFSharedEmbeddings(
config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="w"
self.w = tf.keras.layers.Embedding(
input_dim=config.vocab_size,
output_dim=config.n_embd,
embeddings_initializer=get_initializer(config.initializer_range),
name="w",
)
self.dropout = tf.keras.layers.Dropout(config.embd_pdrop)
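For readers unfamiliar with the old layer: `TFSharedEmbeddings` bundled the token lookup and the output projection behind a `mode="embedding"` / `mode="linear"` switch. The native `tf.keras.layers.Embedding` only does the lookup, so the projection becomes an explicit matmul against the embedding matrix (as done in the LM head further down). A minimal sketch with made-up toy sizes; `get_initializer` is replaced here by a truncated-normal initializer to stay self-contained:

```python
import tensorflow as tf

vocab_size, n_embd, initializer_range = 10, 4, 0.02  # toy values, not CTRL's real sizes

# Native Keras embedding layer, mirroring the new `self.w` above
w = tf.keras.layers.Embedding(
    input_dim=vocab_size,
    output_dim=n_embd,
    embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=initializer_range),
    name="w",
)

input_ids = tf.constant([[1, 2, 3]])
hidden = w(input_ids)                                       # lookup: (1, 3) -> (1, 3, n_embd)
logits = tf.matmul(hidden, w.weights[0], transpose_b=True)  # old "linear" mode, now an explicit matmul
print(hidden.shape, logits.shape)                           # (1, 3, 4) (1, 3, 10)
```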
@@ -246,9 +246,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
def get_input_embeddings(self):
return self.w
def set_input_embeddings(self, value):
self.w.weight = value
self.w.vocab_size = shape_list(value)[0]
def set_input_embeddings(self, new_embeddings):
self.w = new_embeddings
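Note that the setter now replaces the whole layer instead of patching `weight` and `vocab_size` on the shared-embedding wrapper. A tiny sketch of the new contract, using a stand-in layer with toy sizes rather than the real `TFCTRLMainLayer`:

```python
import tensorflow as tf

class TinyMainLayer(tf.keras.layers.Layer):
    """Stand-in for TFCTRLMainLayer, keeping only the embedding plumbing."""

    def __init__(self):
        super().__init__()
        self.w = tf.keras.layers.Embedding(10, 4, name="w")

    def get_input_embeddings(self):
        return self.w

    def set_input_embeddings(self, new_embeddings):
        self.w = new_embeddings  # swap the layer itself, not just its weight tensor

layer = TinyMainLayer()
resized = tf.keras.layers.Embedding(12, 4, name="w")  # e.g. after adding tokens to the vocabulary
layer.set_input_embeddings(resized)
assert layer.get_input_embeddings().input_dim == 12
```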
def _prune_heads(self, heads_to_prune):
"""
@@ -308,7 +307,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))
attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1] + past_length))
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
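The extra `past_length` matters when generating with a cache: attention scores for the new tokens have shape `[batch, heads, new_len, past_len + new_len]`, so the broadcastable mask has to cover the cached positions as well. A toy-shape sketch with assumed variable names:

```python
import tensorflow as tf

batch_size, past_length, new_length = 2, 3, 1        # one new token decoded against a 3-token cache
attention_mask = tf.ones((batch_size, past_length + new_length))  # mask covers past + current tokens

input_shape = (batch_size, new_length)
mask_4d = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1] + past_length))
print(mask_4d.shape)  # (2, 1, 1, 4), broadcastable against scores of shape (2, heads, 1, 4)
```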
@@ -332,15 +331,15 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
if token_type_ids is not None:
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
token_type_embeds = self.w(token_type_ids, mode="embedding")
token_type_embeds = self.w(token_type_ids)
token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, dtype=token_type_embeds.dtype))
else:
token_type_embeds = tf.constant(0.0)
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
if inputs_embeds is None:
check_embeddings_within_bounds(input_ids, self.w.vocab_size)
inputs_embeds = self.w(input_ids, mode="embedding")
check_embeddings_within_bounds(input_ids, self.w.input_dim)
inputs_embeds = self.w(input_ids)
seq_len = input_shape[-1]
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
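With the native layer there is no `mode="embedding"` indirection, and the out-of-range check reads the table size from the layer's `input_dim` attribute rather than the old `vocab_size`. A simplified sketch of the lookup plus bounds check; the assert is a stand-in for transformers' `check_embeddings_within_bounds` helper, not its actual implementation:

```python
import tensorflow as tf

w = tf.keras.layers.Embedding(input_dim=10, output_dim=4, name="w")
input_ids = tf.constant([[1, 5, 9]])

# Simplified stand-in for check_embeddings_within_bounds: every id must be < input_dim
tf.debugging.assert_less(
    input_ids,
    tf.cast(w.input_dim, input_ids.dtype),
    message="input_ids must be strictly smaller than the embedding layer's input_dim",
)
inputs_embeds = w(input_ids)  # plain call replaces the old w(input_ids, mode="embedding")
print(inputs_embeds.shape)    # (1, 3, 4)
```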
@@ -565,39 +564,26 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
return outputs
class TFCTRLLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)
self.config = config
# CTRL has numerical issues in XLA generate
self.supports_xla_generation = False
class TFCTRLBiasLayer(tf.keras.layers.Layer):
"""
Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
so all weights have to be registered in a layer.
"""
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.input_embeddings = input_embeddings
def __init__(self, shape, initializer, trainable, name, **kwargs):
super().__init__(name=name, **kwargs)
self.shape = shape
self.initializer = initializer
self.trainable = trainable
def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
def build(self, input_shape):
self.bias = self.add_weight(
name="bias", shape=self.shape, initializer=self.initializer, trainable=self.trainable
)
super().build(input_shape)
def get_output_embeddings(self):
return self.input_embeddings
def set_output_embeddings(self, value):
self.input_embeddings.weight = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.config.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
return hidden_states
def call(self, x):
return x + self.bias
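Since the interleaved old and new lines above are hard to read without diff markers: the old `TFCTRLLMHead`, which reused the shared embeddings in "linear" mode and added a bias, is deleted, and the only state that still needs a home is that output bias. It gets its own layer because `tf.keras.Model.save_weights` stores weights per layer, so a bare `add_weight` on the model would not round-trip. A standalone, runnable copy of the idea with a toy vocabulary size:

```python
import tensorflow as tf

class BiasLayer(tf.keras.layers.Layer):
    """Bias-only layer, mirroring TFCTRLBiasLayer: registering the bias inside a
    named layer is what makes save_weights / load_weights serialize it."""

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        self.shape = shape
        self.initializer = initializer
        self.trainable = trainable

    def build(self, input_shape):
        self.bias = self.add_weight(
            name="bias", shape=self.shape, initializer=self.initializer, trainable=self.trainable
        )
        super().build(input_shape)

    def call(self, x):
        return x + self.bias

bias_layer = BiasLayer(shape=[1, 10], initializer="zeros", trainable=True, name="lm_head")
logits = bias_layer(tf.zeros((2, 3, 10)))  # bias broadcasts over batch and sequence dims
print(logits.shape)                        # (2, 3, 10)
```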
@add_start_docstrings(
@@ -611,24 +597,53 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFCTRLMainLayer(config, name="transformer")
self.bias_layer = TFCTRLBiasLayer(
name="lm_head", shape=[1, config.vocab_size], initializer="zeros", trainable=True
)
self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")
# CTRL has numerical issues in XLA generate
self.supports_xla_generation = False
def get_output_embeddings(self):
return self.get_input_embeddings()
def get_lm_head(self):
return self.lm_head
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name
def get_bias(self):
return {"lm_head.bias": self.bias_layer.bias}
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs):
def set_bias(self, value):
# Replaces the existing layers containing bias for correct (de)serialization.
vocab_size = value["lm_head.bias"].shape[-1]
self.bias_layer = TFCTRLBiasLayer(
name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=True
)
self.bias_layer.build(None)
self.bias_layer.bias.assign(value["lm_head.bias"])
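`get_output_embeddings` and `set_output_embeddings` now just forward to the input-embedding accessors, since the output projection reuses the same matrix, and the bias keeps its legacy `"lm_head.bias"` key so existing checkpoints still map onto it. Resizing therefore means rebuilding the bias at the new vocabulary size and copying the stored values in. The following simplified sketch mimics that with a bare variable in place of the bias layer; names and sizes are made up:

```python
import tensorflow as tf

# Values handed to set_bias, e.g. after resize_token_embeddings grew the vocabulary to 12
stored = {"lm_head.bias": tf.random.normal((1, 12))}

new_vocab_size = stored["lm_head.bias"].shape[-1]
bias = tf.Variable(tf.zeros((1, new_vocab_size)), trainable=True, name="final_logits_bias/bias")
bias.assign(stored["lm_head.bias"])  # same values, now registered at the new size
print(bias.shape)                    # (1, 12)
```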
# Copied from transformers.models.gpt2.modeling_tf_gpt2.TFGPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past_key_values:
input_ids = tf.expand_dims(input_ids[:, -1], -1)
inputs = tf.expand_dims(inputs[:, -1], -1)
if token_type_ids is not None:
token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)
return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache}
position_ids = kwargs.get("position_ids", None)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None and position_ids is None:
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
if past_key_values:
position_ids = tf.expand_dims(position_ids[:, -1], -1)
return {
"input_ids": inputs,
"attention_mask": attention_mask,
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"token_type_ids": token_type_ids,
}
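The new `prepare_inputs_for_generation` (shared with the GPT-2 implementation, per the comment above) also rebuilds `position_ids` from the attention mask: an exclusive cumulative sum yields 0-based positions that only advance on attended tokens, and when a cache is present only the last token and its position are fed to the model. A toy illustration of that cumsum trick:

```python
import tensorflow as tf

# Left-padded batch: 0 = padding, 1 = real token
attention_mask = tf.constant([[0, 0, 1, 1, 1],
                              [1, 1, 1, 1, 1]])

# Exclusive cumsum -> 0-based positions that only advance on real tokens
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
print(position_ids.numpy())
# [[0 0 0 1 2]
#  [0 1 2 3 4]]

# With past_key_values, only the latest position is kept for the single new token
last_position = tf.expand_dims(position_ids[:, -1], -1)
print(last_position.numpy())
# [[2]
#  [4]]
```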
@unpack_inputs
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
@@ -672,10 +687,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
return_dict=return_dict,
training=training,
)
hidden_states = transformer_outputs[0]
logits = self.lm_head(hidden_states)
logits = tf.matmul(hidden_states, self.transformer.w.weights, transpose_b=True)
logits = self.bias_layer(logits)
loss = None
if labels is not None:
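With the dedicated LM head gone, the weight tying is explicit in the forward pass: hidden states are projected back to vocabulary space with the transpose of the input embedding matrix, and the standalone bias layer is applied on top. A toy-shape sketch of that computation, using random tensors instead of the real model and the embedding matrix `w.weights[0]` directly for clarity:

```python
import tensorflow as tf

batch, seq_len, d_model, vocab = 2, 5, 4, 10  # toy sizes

w = tf.keras.layers.Embedding(vocab, d_model, name="w")  # stands in for the model's input embeddings
_ = w(tf.zeros((1, 1), dtype=tf.int32))                  # call once so the embedding matrix is created

hidden_states = tf.random.normal((batch, seq_len, d_model))  # stand-in for transformer_outputs[0]
bias = tf.zeros((1, vocab))                                  # stand-in for the bias layer's weight

logits = tf.matmul(hidden_states, w.weights[0], transpose_b=True)  # output projection tied to input embeddings
logits = logits + bias
print(logits.shape)  # (2, 5, 10)
```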

tests/models/ctrl/test_modeling_tf_ctrl.py

@@ -225,6 +225,7 @@ class TFCTRLModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
for model_class in self.all_model_classes:
model = model_class(config)
model.build() # may be needed for the get_bias() call below
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
if model_class in list_lm_models:
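The explicit `model.build()` is needed because Keras creates weights lazily: the LM bias introduced in this commit only exists once its layer has been built, so calling `get_bias()` on a never-called model would have nothing to return. A generic illustration of that Keras behaviour, not the transformers test itself:

```python
import tensorflow as tf

class WithBias(tf.keras.layers.Layer):
    def build(self, input_shape):
        self.bias = self.add_weight(name="bias", shape=(1, 10), initializer="zeros")
        super().build(input_shape)

    def call(self, x):
        return x + self.bias

layer = WithBias()
print(layer.weights)       # [] -- no variables exist before the layer is built
layer.build(None)
print(len(layer.weights))  # 1 -- the bias is created during build()
```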