mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
TF: CTRL with native embedding layers (#23456)
This commit is contained in:
parent
eac8dede83
commit
4626df5077
@ -15,10 +15,8 @@
|
||||
# limitations under the License.
|
||||
""" TF 2.0 CTRL model."""
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@ -30,7 +28,6 @@ from ...modeling_tf_utils import (
|
||||
TFModelInputType,
|
||||
TFPreTrainedModel,
|
||||
TFSequenceClassificationLoss,
|
||||
TFSharedEmbeddings,
|
||||
get_initializer,
|
||||
keras_serializable,
|
||||
unpack_inputs,
|
||||
@ -224,8 +221,11 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size)
|
||||
|
||||
self.w = TFSharedEmbeddings(
|
||||
config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="w"
|
||||
self.w = tf.keras.layers.Embedding(
|
||||
input_dim=config.vocab_size,
|
||||
output_dim=config.n_embd,
|
||||
embeddings_initializer=get_initializer(config.initializer_range),
|
||||
name="w",
|
||||
)
|
||||
|
||||
self.dropout = tf.keras.layers.Dropout(config.embd_pdrop)
|
||||
@ -246,9 +246,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
def get_input_embeddings(self):
|
||||
return self.w
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.w.weight = value
|
||||
self.w.vocab_size = shape_list(value)[0]
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.w = new_embeddings
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""
|
||||
@ -308,7 +307,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))
|
||||
attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1] + past_length))
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
@ -332,15 +331,15 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
|
||||
token_type_embeds = self.w(token_type_ids, mode="embedding")
|
||||
token_type_embeds = self.w(token_type_ids)
|
||||
token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, dtype=token_type_embeds.dtype))
|
||||
else:
|
||||
token_type_embeds = tf.constant(0.0)
|
||||
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
||||
|
||||
if inputs_embeds is None:
|
||||
check_embeddings_within_bounds(input_ids, self.w.vocab_size)
|
||||
inputs_embeds = self.w(input_ids, mode="embedding")
|
||||
check_embeddings_within_bounds(input_ids, self.w.input_dim)
|
||||
inputs_embeds = self.w(input_ids)
|
||||
seq_len = input_shape[-1]
|
||||
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
|
||||
|
||||
@ -565,39 +564,26 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
|
||||
return outputs
|
||||
|
||||
|
||||
class TFCTRLLMHead(tf.keras.layers.Layer):
|
||||
def __init__(self, config, input_embeddings, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.config = config
|
||||
# CTRL has numerical issues in XLA generate
|
||||
self.supports_xla_generation = False
|
||||
class TFCTRLBiasLayer(tf.keras.layers.Layer):
|
||||
"""
|
||||
Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
|
||||
so all weights have to be registered in a layer.
|
||||
"""
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.input_embeddings = input_embeddings
|
||||
def __init__(self, shape, initializer, trainable, name, **kwargs):
|
||||
super().__init__(name=name, **kwargs)
|
||||
self.shape = shape
|
||||
self.initializer = initializer
|
||||
self.trainable = trainable
|
||||
|
||||
def build(self, input_shape=None):
|
||||
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
|
||||
def build(self, input_shape):
|
||||
self.bias = self.add_weight(
|
||||
name="bias", shape=self.shape, initializer=self.initializer, trainable=self.trainable
|
||||
)
|
||||
super().build(input_shape)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.input_embeddings
|
||||
|
||||
def set_output_embeddings(self, value):
|
||||
self.input_embeddings.weight = value
|
||||
self.input_embeddings.vocab_size = shape_list(value)[0]
|
||||
|
||||
def get_bias(self):
|
||||
return {"bias": self.bias}
|
||||
|
||||
def set_bias(self, value):
|
||||
self.bias = value["bias"]
|
||||
self.config.vocab_size = shape_list(value["bias"])[0]
|
||||
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.input_embeddings(hidden_states, mode="linear")
|
||||
hidden_states = hidden_states + self.bias
|
||||
return hidden_states
|
||||
def call(self, x):
|
||||
return x + self.bias
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@ -611,24 +597,53 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.transformer = TFCTRLMainLayer(config, name="transformer")
|
||||
self.bias_layer = TFCTRLBiasLayer(
|
||||
name="lm_head", shape=[1, config.vocab_size], initializer="zeros", trainable=True
|
||||
)
|
||||
|
||||
self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")
|
||||
# CTRL has numerical issues in XLA generate
|
||||
self.supports_xla_generation = False
|
||||
def get_output_embeddings(self):
|
||||
return self.get_input_embeddings()
|
||||
|
||||
def get_lm_head(self):
|
||||
return self.lm_head
|
||||
def set_output_embeddings(self, value):
|
||||
self.set_input_embeddings(value)
|
||||
|
||||
def get_prefix_bias_name(self):
|
||||
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
|
||||
return self.name + "/" + self.lm_head.name
|
||||
def get_bias(self):
|
||||
return {"lm_head.bias": self.bias_layer.bias}
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs):
|
||||
def set_bias(self, value):
|
||||
# Replaces the existing layers containing bias for correct (de)serialization.
|
||||
vocab_size = value["lm_head.bias"].shape[-1]
|
||||
self.bias_layer = TFCTRLBiasLayer(
|
||||
name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=True
|
||||
)
|
||||
self.bias_layer.build(None)
|
||||
self.bias_layer.bias.assign(value["lm_head.bias"])
|
||||
|
||||
# Copied from transformers.models.gpt2.modeling_tf_gpt2.TFGPT2LMHeadModel.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
|
||||
token_type_ids = kwargs.get("token_type_ids", None)
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if past_key_values:
|
||||
input_ids = tf.expand_dims(input_ids[:, -1], -1)
|
||||
inputs = tf.expand_dims(inputs[:, -1], -1)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)
|
||||
|
||||
return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache}
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
|
||||
if attention_mask is not None and position_ids is None:
|
||||
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
|
||||
if past_key_values:
|
||||
position_ids = tf.expand_dims(position_ids[:, -1], -1)
|
||||
|
||||
return {
|
||||
"input_ids": inputs,
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
|
||||
@ -672,10 +687,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = tf.matmul(hidden_states, self.transformer.w.weights, transpose_b=True)
|
||||
logits = self.bias_layer(logits)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -225,6 +225,7 @@ class TFCTRLModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.build() # may be needed for the get_bias() call below
|
||||
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
|
||||
|
||||
if model_class in list_lm_models:
|
||||
|
Loading…
Reference in New Issue
Block a user