[TF models] Common attributes as per #1721
This commit is contained in:
parent 872403be1c
commit 70d97ddd60
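This commit adds the common get_input_embeddings() / get_output_embeddings() accessors to the TF 2.0 models, mirroring the common attributes discussed in #1721. A minimal usage sketch of the new accessors (assuming TensorFlow 2.0 and a downloadable checkpoint; 'bert-base-uncased' is only an example name):

import tensorflow as tf
from transformers import TFBertModel

model = TFBertModel.from_pretrained('bert-base-uncased')

input_emb = model.get_input_embeddings()    # the model's token embedding layer
output_emb = model.get_output_embeddings()  # None for models without an output/LM head

assert isinstance(input_emb, tf.keras.layers.Layer)
assert output_emb is None or isinstance(output_emb, tf.keras.layers.Layer)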
@@ -460,6 +460,9 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         self.encoder = TFBertEncoder(config, name='encoder')
         self.pooler = TFBertPooler(config, name='pooler')

+    def get_input_embeddings(self):
+        return self.embeddings
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError

@@ -702,6 +705,9 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
         self.nsp = TFBertNSPHead(config, name='nsp___cls')
         self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')

+    def get_output_embeddings(self):
+        return self.bert.embeddings
+
     def call(self, inputs, **kwargs):
         outputs = self.bert(inputs, **kwargs)

@@ -747,6 +753,9 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
         self.bert = TFBertMainLayer(config, name='bert')
         self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')

+    def get_output_embeddings(self):
+        return self.bert.embeddings
+
     def call(self, inputs, **kwargs):
         outputs = self.bert(inputs, **kwargs)

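In both BERT head models above, get_output_embeddings() returns the same self.bert.embeddings layer that backs the input embeddings, since the MLM head is built on the tied embedding layer. A quick sketch of that tie (hypothetical check, assuming the classes are importable; BertConfig() with defaults builds a bert-base-sized model):

from transformers import BertConfig, TFBertForMaskedLM

model = TFBertForMaskedLM(BertConfig())

# Per this diff, both accessors resolve to the same tied embedding layer.
assert model.get_output_embeddings() is model.get_input_embeddings()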
@@ -192,6 +192,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
                                  name='h_._{}'.format(i)) for i in range(config.n_layer)]
         self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="layernorm")

+    def get_input_embeddings(self):
+        return self.w
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError

@@ -480,6 +483,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):

         self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")

+    def get_output_embeddings(self):
+        return self.lm_head.input_embeddings
+
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_states = transformer_outputs[0]
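For heads that wrap a shared embedding, such as TFCTRLLMHead here (and the DistilBERT, XLM and XLNet heads below), get_output_embeddings() returns the head's input_embeddings attribute, i.e. the token embedding layer the head was constructed with. A hypothetical sanity check along those lines (assuming TFCTRLLMHead keeps the layer passed to it as .input_embeddings, as the accessor implies):

from transformers import CTRLConfig, TFCTRLLMHeadModel

model = TFCTRLLMHeadModel(CTRLConfig())

# The LM head was built around self.transformer.w, so both accessors
# should resolve back to that embedding layer.
assert model.get_output_embeddings() is model.transformer.w
assert model.get_input_embeddings() is model.transformer.w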
@@ -398,6 +398,9 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
         self.embeddings = TFEmbeddings(config, name="embeddings")   # Embeddings
         self.transformer = TFTransformer(config, name="transformer") # Encoder

+    def get_input_embeddings(self):
+        return self.embeddings
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError

@@ -613,6 +616,9 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
         self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
         self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")

+    def get_output_embeddings(self):
+        return self.vocab_projector.input_embeddings
+
     def call(self, inputs, **kwargs):
         distilbert_output = self.distilbert(inputs, **kwargs)

@@ -219,6 +219,9 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
                           name='h_._{}'.format(i)) for i in range(config.n_layer)]
         self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')

+    def get_input_embeddings(self):
+        return self.wte
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError

@@ -490,6 +493,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
         super(TFGPT2LMHeadModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFGPT2MainLayer(config, name='transformer')

+    def get_output_embeddings(self):
+        return self.transformer.wte
+
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_states = transformer_outputs[0]
@@ -560,6 +566,9 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
         self.transformer = TFGPT2MainLayer(config, name='transformer')
         self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head')

+    def get_output_embeddings(self):
+        return self.transformer.wte
+
     def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
@@ -217,6 +217,9 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
                            scale=True,
                            name='h_._{}'.format(i)) for i in range(config.n_layer)]

+    def get_input_embeddings(self):
+        return self.tokens_embed
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError

@@ -462,6 +465,9 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel):
         super(TFOpenAIGPTLMHeadModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')

+    def get_output_embeddings(self):
+        return self.transformer.tokens_embed
+
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_states = transformer_outputs[0]
@@ -524,6 +530,9 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
         self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
         self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head')

+    def get_output_embeddings(self):
+        return self.transformer.tokens_embed
+
     def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
@@ -65,6 +65,9 @@ class TFRobertaMainLayer(TFBertMainLayer):
         super(TFRobertaMainLayer, self).__init__(config, **kwargs)
         self.embeddings = TFRobertaEmbeddings(config, name='embeddings')

+    def get_input_embeddings(self):
+        return self.embeddings
+

 class TFRobertaPreTrainedModel(TFPreTrainedModel):
     """ An abstract class to handle weights initialization and
@@ -280,6 +283,9 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel):
         self.roberta = TFRobertaMainLayer(config, name="roberta")
         self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head")

+    def get_output_embeddings(self):
+        return self.lm_head.decoder
+
     def call(self, inputs, **kwargs):
         outputs = self.roberta(inputs, **kwargs)

@@ -413,6 +413,9 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
                                         name='r_r_bias')
         super(TFTransfoXLMainLayer, self).build(input_shape)

+    def get_input_embeddings(self):
+        return self.word_emb
+
     def _resize_token_embeddings(self, new_num_tokens):
         return self.word_emb

@@ -65,6 +65,21 @@ class TFPreTrainedModel(tf.keras.Model):
         # Save config in model
         self.config = config

+    def get_input_embeddings(self):
+        """ Get model's input embeddings
+        """
+        base_model = getattr(self, self.base_model_prefix, self)
+        if base_model is not self:
+            return base_model.get_input_embeddings()
+        else:
+            raise NotImplementedError
+
+    def get_output_embeddings(self):
+        """ Get model's output embeddings
+            Return None if the model doesn't have output embeddings
+        """
+        return None # Overwrite for models with output embeddings
+
     def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
         """ Build a resized Embedding Variable from a provided token Embedding Module.
             Increasing the size will add newly initialized vectors at the end
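The base-class accessor works by delegation: get_input_embeddings() looks up the sub-module named by base_model_prefix (e.g. 'bert', 'transformer', 'distilbert') and forwards the call, so only the *MainLayer classes need to know where the embeddings live. A stripped-down sketch of that mechanism with toy classes (not the real ones):

import tensorflow as tf

class ToyMainLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(ToyMainLayer, self).__init__()
        self.embeddings = tf.keras.layers.Embedding(100, 16)

    def get_input_embeddings(self):
        return self.embeddings

class ToyPreTrainedModel(tf.keras.Model):
    base_model_prefix = ""

    def get_input_embeddings(self):
        # Same lookup as TFPreTrainedModel: delegate to the base model if there is one.
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            raise NotImplementedError

class ToyModel(ToyPreTrainedModel):
    base_model_prefix = "toy"

    def __init__(self):
        super(ToyModel, self).__init__()
        self.toy = ToyMainLayer()

model = ToyModel()
assert model.get_input_embeddings() is model.toy.embeddings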
@@ -277,6 +277,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
            self.prune_heads({int(layer): list(map(int, heads))})


+    def get_input_embeddings(self):
+        return self.embeddings
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError

@@ -641,6 +644,8 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
         self.transformer = TFXLMMainLayer(config, name='transformer')
         self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name='pred_layer_._proj')

+    def get_output_embeddings(self):
+        return self.pred_layer.input_embeddings

     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
@@ -371,6 +371,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         self.layer = [TFXLNetLayer(config, name='layer_._{}'.format(i)) for i in range(config.n_layer)]
         self.dropout = tf.keras.layers.Dropout(config.dropout)

+    def get_input_embeddings(self):
+        return self.word_embedding
+
     def build(self, input_shape):
         initializer = get_initializer(self.initializer_range)
         self.mask_emb = self.add_weight(shape=(1, 1, self.d_model),
@@ -854,6 +857,9 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
         self.transformer = TFXLNetMainLayer(config, name='transformer')
         self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name='lm_loss')

+    def get_output_embeddings(self):
+        return self.lm_loss.input_embeddings
+
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_state = transformer_outputs[0]
@@ -360,6 +360,16 @@ class TFCommonTestCases:
            #     self.assertTrue(models_equal)


+        def test_model_common_attributes(self):
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            for model_class in self.all_model_classes:
+                model = model_class(config)
+                assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+                x = model.get_output_embeddings()
+                assert x is None or isinstance(x, tf.keras.layers.Layer)
+
+
         def test_tie_model_weights(self):
             pass
             # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()