[TF models] Common attributes as per #1721

Julien Chaumond 2019-11-11 16:30:22 -05:00
parent 872403be1c
commit 70d97ddd60
11 changed files with 84 additions and 0 deletions
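
In short: each TF 2.0 main layer now exposes its token embedding layer through get_input_embeddings(), and each model with an LM head exposes its output projection through get_output_embeddings(), mirroring the PyTorch API from #1721. A minimal usage sketch (hypothetical snippet, not part of this diff; assumes the pretrained 'gpt2' TF weights can be downloaded):

import tensorflow as tf
from transformers import TFGPT2LMHeadModel

model = TFGPT2LMHeadModel.from_pretrained('gpt2')
emb = model.get_input_embeddings()             # resolved through the base model
assert isinstance(emb, tf.keras.layers.Layer)
# GPT-2 ties its input and output embeddings, so both accessors
# return the very same layer object.
assert model.get_output_embeddings() is emb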

View File

@@ -460,6 +460,9 @@ class TFBertMainLayer(tf.keras.layers.Layer):
self.encoder = TFBertEncoder(config, name='encoder')
self.pooler = TFBertPooler(config, name='pooler')
def get_input_embeddings(self):
return self.embeddings
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
@@ -702,6 +705,9 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
self.nsp = TFBertNSPHead(config, name='nsp___cls')
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')
def get_output_embeddings(self):
return self.bert.embeddings
def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs)
@@ -747,6 +753,9 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name='bert')
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')
def get_output_embeddings(self):
return self.bert.embeddings
def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs)

View File

@@ -192,6 +192,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
name='h_._{}'.format(i)) for i in range(config.n_layer)]
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="layernorm")
def get_input_embeddings(self):
return self.w
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
@@ -480,6 +483,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")
def get_output_embeddings(self):
return self.lm_head.input_embeddings
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
hidden_states = transformer_outputs[0]

View File

@@ -398,6 +398,9 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
self.embeddings = TFEmbeddings(config, name="embeddings") # Embeddings
self.transformer = TFTransformer(config, name="transformer") # Encoder
def get_input_embeddings(self):
return self.embeddings
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
@@ -613,6 +616,9 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")
def get_output_embeddings(self):
return self.vocab_projector.input_embeddings
def call(self, inputs, **kwargs):
distilbert_output = self.distilbert(inputs, **kwargs)
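
Several of the heads in this commit (CTRL's lm_head above, DistilBERT's vocab_projector here, and the XLM and XLNet heads below) keep a reference to the embedding layer under an input_embeddings attribute and reuse its matrix as the output projection; get_output_embeddings() simply surfaces that reference. A toy version of the pattern (illustrative names, not from this diff):

import tensorflow as tf

class TinyLMHead(tf.keras.layers.Layer):
    def __init__(self, input_embeddings, **kwargs):
        super(TinyLMHead, self).__init__(**kwargs)
        self.input_embeddings = input_embeddings  # shared reference, not a copy

    def call(self, hidden_states):
        # Project hidden states onto the (tied) embedding matrix.
        return tf.matmul(hidden_states, self.input_embeddings.embeddings, transpose_b=True)

embeddings = tf.keras.layers.Embedding(1000, 64)
_ = embeddings(tf.constant([[1, 2]]))  # build the layer so .embeddings exists
head = TinyLMHead(embeddings)
logits = head(tf.zeros([2, 5, 64]))    # shape (2, 5, 1000)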

View File

@@ -219,6 +219,9 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
name='h_._{}'.format(i)) for i in range(config.n_layer)]
self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')
def get_input_embeddings(self):
return self.wte
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
@@ -490,6 +493,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
super(TFGPT2LMHeadModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFGPT2MainLayer(config, name='transformer')
def get_output_embeddings(self):
return self.transformer.wte
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
hidden_states = transformer_outputs[0]
@@ -560,6 +566,9 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
self.transformer = TFGPT2MainLayer(config, name='transformer')
self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head')
def get_output_embeddings(self):
return self.transformer.wte
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
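
GPT-2 wires this differently: get_output_embeddings() returns the transformer's wte itself, a single shared matrix used both for the input lookup and for producing output logits (roughly what TFSharedEmbeddings does with its 'embedding' and 'linear' modes). The mechanics, sketched with illustrative values not taken from this diff:

import tensorflow as tf

vocab_size, n_embd = 50257, 768
wte = tf.Variable(tf.random.normal([vocab_size, n_embd], stddev=0.02))

ids = tf.constant([[464, 3290]])               # a batch of token ids
h = tf.gather(wte, ids)                        # 'embedding' mode: lookup
logits = tf.matmul(h, wte, transpose_b=True)   # 'linear' mode: tied output projection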

View File

@@ -217,6 +217,9 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
scale=True,
name='h_._{}'.format(i)) for i in range(config.n_layer)]
def get_input_embeddings(self):
return self.tokens_embed
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
@@ -462,6 +465,9 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel):
super(TFOpenAIGPTLMHeadModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
def get_output_embeddings(self):
return self.transformer.tokens_embed
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
hidden_states = transformer_outputs[0]
@@ -524,6 +530,9 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head')
def get_output_embeddings(self):
return self.transformer.tokens_embed
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]

View File

@@ -65,6 +65,9 @@ class TFRobertaMainLayer(TFBertMainLayer):
super(TFRobertaMainLayer, self).__init__(config, **kwargs)
self.embeddings = TFRobertaEmbeddings(config, name='embeddings')
def get_input_embeddings(self):
return self.embeddings
class TFRobertaPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
@@ -280,6 +283,9 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel):
self.roberta = TFRobertaMainLayer(config, name="roberta")
self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head")
def get_output_embeddings(self):
return self.lm_head.decoder
def call(self, inputs, **kwargs):
outputs = self.roberta(inputs, **kwargs)

View File

@@ -413,6 +413,9 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
name='r_r_bias')
super(TFTransfoXLMainLayer, self).build(input_shape)
def get_input_embeddings(self):
return self.word_emb
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError

View File

@@ -65,6 +65,21 @@ class TFPreTrainedModel(tf.keras.Model):
# Save config in model
self.config = config
def get_input_embeddings(self):
""" Get model's input embeddings
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
return base_model.get_input_embeddings()
else:
raise NotImplementedError
def get_output_embeddings(self):
""" Get model's output embeddings
Return None if the model doesn't have output embeddings
"""
return None # Overwrite for models with output embeddings
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
""" Build a resized Embedding Variable from a provided token Embedding Module.
Increasing the size will add newly initialized vectors at the end
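
The base-class accessor above resolves the underlying main layer through base_model_prefix, so every head model gets get_input_embeddings() for free, while get_output_embeddings() defaults to None and is overridden per model as in the hunks above. A self-contained sketch of the delegation (toy classes with illustrative names, not from this diff):

import tensorflow as tf

class ToyMainLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(ToyMainLayer, self).__init__(**kwargs)
        self.embeddings = tf.keras.layers.Embedding(100, 16)

    def get_input_embeddings(self):
        return self.embeddings

class ToyModel(tf.keras.Model):
    base_model_prefix = 'transformer'

    def __init__(self):
        super(ToyModel, self).__init__()
        self.transformer = ToyMainLayer(name='transformer')

    def get_input_embeddings(self):
        # Same resolution logic as TFPreTrainedModel above.
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        raise NotImplementedError

assert isinstance(ToyModel().get_input_embeddings(), tf.keras.layers.Layer)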

View File

@@ -277,6 +277,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
self.prune_heads({int(layer): list(map(int, heads))})
def get_input_embeddings(self):
return self.embeddings
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
@@ -641,6 +644,8 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
self.transformer = TFXLMMainLayer(config, name='transformer')
self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name='pred_layer_._proj')
def get_output_embeddings(self):
return self.pred_layer.input_embeddings
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)

View File

@@ -371,6 +371,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
self.layer = [TFXLNetLayer(config, name='layer_._{}'.format(i)) for i in range(config.n_layer)]
self.dropout = tf.keras.layers.Dropout(config.dropout)
def get_input_embeddings(self):
return self.word_embedding
def build(self, input_shape):
initializer = get_initializer(self.initializer_range)
self.mask_emb = self.add_weight(shape=(1, 1, self.d_model),
@@ -854,6 +857,9 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
self.transformer = TFXLNetMainLayer(config, name='transformer')
self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name='lm_loss')
def get_output_embeddings(self):
return self.lm_loss.input_embeddings
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
hidden_state = transformer_outputs[0]

View File

@@ -360,6 +360,16 @@ class TFCommonTestCases:
# self.assertTrue(models_equal)
def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
x = model.get_output_embeddings()
assert x is None or isinstance(x, tf.keras.layers.Layer)
def test_tie_model_weights(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
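
The new test_model_common_attributes builds every TF model class from a fresh config and checks both accessors. The same assertions can be tried by hand outside the harness (hypothetical snippet; fresh config, no pretrained weights needed):

import tensorflow as tf
from transformers import BertConfig, TFBertModel, TFBertForMaskedLM

config = BertConfig()
for model_class in (TFBertModel, TFBertForMaskedLM):
    model = model_class(config)
    assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
    x = model.get_output_embeddings()
    # None for TFBertModel (no LM head); the tied embedding layer for TFBertForMaskedLM.
    assert x is None or isinstance(x, tf.keras.layers.Layer)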