Mirror of https://github.com/huggingface/transformers.git

[TF models] Common attributes as per #1721

commit 70d97ddd60
parent 872403be1c
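The commit gives the TF 2.0 models common embedding accessors, mirroring the common attributes discussed in #1721: every main layer answers get_input_embeddings(), head models pick it up through a base-class fallback, and models with an LM head also answer get_output_embeddings(). A rough sketch of the intended use follows; it is not part of the diff, assumes a TF 2.x environment with this revision of transformers installed, and builds a randomly initialised GPT-2 from the default config so nothing has to be downloaded.

    import tensorflow as tf
    from transformers import GPT2Config, TFGPT2LMHeadModel

    model = TFGPT2LMHeadModel(GPT2Config())

    # Input embeddings: the base-class fallback forwards the call to the underlying
    # TFGPT2MainLayer, which returns its shared wte embedding layer.
    input_embeddings = model.get_input_embeddings()
    assert isinstance(input_embeddings, tf.keras.layers.Layer)

    # Output embeddings: GPT-2 ties its LM head to the same wte layer, so both
    # accessors hand back the very same object.
    assert model.get_output_embeddings() is input_embeddings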
@@ -460,6 +460,9 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         self.encoder = TFBertEncoder(config, name='encoder')
         self.pooler = TFBertPooler(config, name='pooler')
 
+    def get_input_embeddings(self):
+        return self.embeddings
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
 
@@ -702,6 +705,9 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
         self.nsp = TFBertNSPHead(config, name='nsp___cls')
         self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')
 
+    def get_output_embeddings(self):
+        return self.bert.embeddings
+
     def call(self, inputs, **kwargs):
         outputs = self.bert(inputs, **kwargs)
 
@@ -747,6 +753,9 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
         self.bert = TFBertMainLayer(config, name='bert')
         self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')
 
+    def get_output_embeddings(self):
+        return self.bert.embeddings
+
     def call(self, inputs, **kwargs):
         outputs = self.bert(inputs, **kwargs)
 
@@ -192,6 +192,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
                                  name='h_._{}'.format(i)) for i in range(config.n_layer)]
         self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="layernorm")
 
+    def get_input_embeddings(self):
+        return self.w
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
 
@@ -480,6 +483,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
 
         self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")
 
+    def get_output_embeddings(self):
+        return self.lm_head.input_embeddings
+
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_states = transformer_outputs[0]
@@ -398,6 +398,9 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
         self.embeddings = TFEmbeddings(config, name="embeddings") # Embeddings
         self.transformer = TFTransformer(config, name="transformer") # Encoder
 
+    def get_input_embeddings(self):
+        return self.embeddings
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
 
@@ -613,6 +616,9 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
         self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
         self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")
 
+    def get_output_embeddings(self):
+        return self.vocab_projector.input_embeddings
+
     def call(self, inputs, **kwargs):
         distilbert_output = self.distilbert(inputs, **kwargs)
 
@@ -219,6 +219,9 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
                           name='h_._{}'.format(i)) for i in range(config.n_layer)]
         self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')
 
+    def get_input_embeddings(self):
+        return self.wte
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
 
@@ -490,6 +493,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
         super(TFGPT2LMHeadModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFGPT2MainLayer(config, name='transformer')
 
+    def get_output_embeddings(self):
+        return self.transformer.wte
+
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_states = transformer_outputs[0]
@@ -560,6 +566,9 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
         self.transformer = TFGPT2MainLayer(config, name='transformer')
         self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head')
 
+    def get_output_embeddings(self):
+        return self.transformer.wte
+
     def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
@@ -217,6 +217,9 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
                           scale=True,
                           name='h_._{}'.format(i)) for i in range(config.n_layer)]
 
+    def get_input_embeddings(self):
+        return self.tokens_embed
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
 
@@ -462,6 +465,9 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel):
         super(TFOpenAIGPTLMHeadModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
 
+    def get_output_embeddings(self):
+        return self.transformer.tokens_embed
+
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_states = transformer_outputs[0]
@@ -524,6 +530,9 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
         self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
         self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head')
 
+    def get_output_embeddings(self):
+        return self.transformer.tokens_embed
+
     def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
@@ -65,6 +65,9 @@ class TFRobertaMainLayer(TFBertMainLayer):
         super(TFRobertaMainLayer, self).__init__(config, **kwargs)
         self.embeddings = TFRobertaEmbeddings(config, name='embeddings')
 
+    def get_input_embeddings(self):
+        return self.embeddings
+
 
 class TFRobertaPreTrainedModel(TFPreTrainedModel):
     """ An abstract class to handle weights initialization and
@@ -280,6 +283,9 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel):
         self.roberta = TFRobertaMainLayer(config, name="roberta")
         self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head")
 
+    def get_output_embeddings(self):
+        return self.lm_head.decoder
+
     def call(self, inputs, **kwargs):
         outputs = self.roberta(inputs, **kwargs)
 
@@ -413,6 +413,9 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
                                             name='r_r_bias')
         super(TFTransfoXLMainLayer, self).build(input_shape)
 
+    def get_input_embeddings(self):
+        return self.word_emb
+
     def _resize_token_embeddings(self, new_num_tokens):
         return self.word_emb
 
@@ -65,6 +65,21 @@ class TFPreTrainedModel(tf.keras.Model):
         # Save config in model
         self.config = config
 
+    def get_input_embeddings(self):
+        """ Get model's input embeddings
+        """
+        base_model = getattr(self, self.base_model_prefix, self)
+        if base_model is not self:
+            return base_model.get_input_embeddings()
+        else:
+            raise NotImplementedError
+
+    def get_output_embeddings(self):
+        """ Get model's output embeddings
+            Return None if the model doesn't have output embeddings
+        """
+        return None # Overwrite for models with output embeddings
+
     def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
         """ Build a resized Embedding Variable from a provided token Embedding Module.
             Increasing the size will add newly initialized vectors at the end
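The base-class fallback added above is what lets head models answer get_input_embeddings() without overriding it: the lookup walks to the attribute named by base_model_prefix and defers to it. A stripped-down illustration of that routing, using stand-in classes rather than the real transformers ones:

    class MainLayer:
        """Stand-in for a main layer such as TFBertMainLayer: it owns the embeddings."""
        def __init__(self):
            self.embeddings = object()      # placeholder for the embedding layer

        def get_input_embeddings(self):
            return self.embeddings


    class PreTrainedModel:
        """Stand-in for TFPreTrainedModel: it only knows the name of its main layer."""
        base_model_prefix = "bert"

        def get_input_embeddings(self):
            base_model = getattr(self, self.base_model_prefix, self)
            if base_model is not self:      # a head model wrapping a main layer
                return base_model.get_input_embeddings()
            raise NotImplementedError       # a bare main layer must override this


    class MaskedLMModel(PreTrainedModel):
        """Stand-in for a head model such as TFBertForMaskedLM."""
        def __init__(self):
            self.bert = MainLayer()         # attribute name matches base_model_prefix


    model = MaskedLMModel()
    assert model.get_input_embeddings() is model.bert.embeddings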
@@ -277,6 +277,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
                 self.prune_heads({int(layer): list(map(int, heads))})
 
 
+    def get_input_embeddings(self):
+        return self.embeddings
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
 
@@ -641,6 +644,8 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
         self.transformer = TFXLMMainLayer(config, name='transformer')
         self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name='pred_layer_._proj')
 
+    def get_output_embeddings(self):
+        return self.pred_layer.input_embeddings
 
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
@@ -371,6 +371,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         self.layer = [TFXLNetLayer(config, name='layer_._{}'.format(i)) for i in range(config.n_layer)]
         self.dropout = tf.keras.layers.Dropout(config.dropout)
 
+    def get_input_embeddings(self):
+        return self.word_embedding
+
     def build(self, input_shape):
         initializer = get_initializer(self.initializer_range)
         self.mask_emb = self.add_weight(shape=(1, 1, self.d_model),
@@ -854,6 +857,9 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
         self.transformer = TFXLNetMainLayer(config, name='transformer')
         self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name='lm_loss')
 
+    def get_output_embeddings(self):
+        return self.lm_loss.input_embeddings
+
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_state = transformer_outputs[0]
@@ -360,6 +360,16 @@ class TFCommonTestCases:
             # self.assertTrue(models_equal)
 
 
+        def test_model_common_attributes(self):
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            for model_class in self.all_model_classes:
+                model = model_class(config)
+                assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+                x = model.get_output_embeddings()
+                assert x is None or isinstance(x, tf.keras.layers.Layer)
+
+
         def test_tie_model_weights(self):
             pass
             # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
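The new test accepts either None or a Keras layer from get_output_embeddings() because encoder-only models keep the base-class default and report no output embeddings; only LM-head models override the accessor. An illustrative check, again assuming a TF 2.x environment with this revision of transformers installed and using randomly initialised models:

    import tensorflow as tf
    from transformers import BertConfig, TFBertModel, TFBertForMaskedLM

    config = BertConfig()

    encoder_only = TFBertModel(config)
    assert isinstance(encoder_only.get_input_embeddings(), tf.keras.layers.Layer)
    assert encoder_only.get_output_embeddings() is None       # no LM head, base-class default

    with_lm_head = TFBertForMaskedLM(config)
    assert isinstance(with_lm_head.get_output_embeddings(), tf.keras.layers.Layer)  # the tied BERT embeddings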