diff --git a/transformers/modeling_tf_bert.py b/transformers/modeling_tf_bert.py
index a1275db9747..66d5efd87c9 100644
--- a/transformers/modeling_tf_bert.py
+++ b/transformers/modeling_tf_bert.py
@@ -460,6 +460,9 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         self.encoder = TFBertEncoder(config, name='encoder')
         self.pooler = TFBertPooler(config, name='pooler')
 
+    def get_input_embeddings(self):
+        return self.embeddings
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
 
@@ -702,6 +705,9 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
         self.nsp = TFBertNSPHead(config, name='nsp___cls')
         self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')
 
+    def get_output_embeddings(self):
+        return self.bert.embeddings
+
     def call(self, inputs, **kwargs):
         outputs = self.bert(inputs, **kwargs)
 
@@ -747,6 +753,9 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
         self.bert = TFBertMainLayer(config, name='bert')
         self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')
 
+    def get_output_embeddings(self):
+        return self.bert.embeddings
+
     def call(self, inputs, **kwargs):
         outputs = self.bert(inputs, **kwargs)
 
diff --git a/transformers/modeling_tf_ctrl.py b/transformers/modeling_tf_ctrl.py
index dea590c5c5f..99738a8b145 100644
--- a/transformers/modeling_tf_ctrl.py
+++ b/transformers/modeling_tf_ctrl.py
@@ -192,6 +192,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
                                          name='h_._{}'.format(i)) for i in range(config.n_layer)]
         self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="layernorm")
 
+    def get_input_embeddings(self):
+        return self.w
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
 
@@ -480,6 +483,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
 
         self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")
 
+    def get_output_embeddings(self):
+        return self.lm_head.input_embeddings
+
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_states = transformer_outputs[0]
diff --git a/transformers/modeling_tf_distilbert.py b/transformers/modeling_tf_distilbert.py
index 65acb9e142e..4b1f3e676b5 100644
--- a/transformers/modeling_tf_distilbert.py
+++ b/transformers/modeling_tf_distilbert.py
@@ -398,6 +398,9 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
         self.embeddings = TFEmbeddings(config, name="embeddings")   # Embeddings
         self.transformer = TFTransformer(config, name="transformer") # Encoder
 
+    def get_input_embeddings(self):
+        return self.embeddings
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
 
@@ -613,6 +616,9 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
         self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
         self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")
 
+    def get_output_embeddings(self):
+        return self.vocab_projector.input_embeddings
+
     def call(self, inputs, **kwargs):
         distilbert_output = self.distilbert(inputs, **kwargs)
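# Illustrative sketch, not part of the patch: the new accessors give one uniform
# way to reach the token-embedding layers, regardless of where each architecture
# stores them. A minimal check against the DistilBERT changes above, using a
# default config so no pretrained weights are needed:

from transformers import DistilBertConfig, TFDistilBertForMaskedLM

model = TFDistilBertForMaskedLM(DistilBertConfig())
assert model.get_input_embeddings() is model.distilbert.embeddings
assert model.get_output_embeddings() is model.vocab_projector.input_embeddings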
diff --git a/transformers/modeling_tf_gpt2.py b/transformers/modeling_tf_gpt2.py
index 50d58a6749b..23866a1a0a4 100644
--- a/transformers/modeling_tf_gpt2.py
+++ b/transformers/modeling_tf_gpt2.py
@@ -219,6 +219,9 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
                                 name='h_._{}'.format(i)) for i in range(config.n_layer)]
         self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')
 
+    def get_input_embeddings(self):
+        return self.wte
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
 
@@ -490,6 +493,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
         super(TFGPT2LMHeadModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFGPT2MainLayer(config, name='transformer')
 
+    def get_output_embeddings(self):
+        return self.transformer.wte
+
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_states = transformer_outputs[0]
@@ -560,6 +566,9 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
         self.transformer = TFGPT2MainLayer(config, name='transformer')
         self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head')
 
+    def get_output_embeddings(self):
+        return self.transformer.wte
+
     def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
diff --git a/transformers/modeling_tf_openai.py b/transformers/modeling_tf_openai.py
index 18afa85dced..bddd9338b17 100644
--- a/transformers/modeling_tf_openai.py
+++ b/transformers/modeling_tf_openai.py
@@ -217,6 +217,9 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
                                          scale=True,
                                          name='h_._{}'.format(i)) for i in range(config.n_layer)]
 
+    def get_input_embeddings(self):
+        return self.tokens_embed
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
 
@@ -462,6 +465,9 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel):
         super(TFOpenAIGPTLMHeadModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
 
+    def get_output_embeddings(self):
+        return self.transformer.tokens_embed
+
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_states = transformer_outputs[0]
@@ -524,6 +530,9 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
         self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
         self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head')
 
+    def get_output_embeddings(self):
+        return self.transformer.tokens_embed
+
     def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
diff --git a/transformers/modeling_tf_roberta.py b/transformers/modeling_tf_roberta.py
index 32abea659ef..c335910dc69 100644
--- a/transformers/modeling_tf_roberta.py
+++ b/transformers/modeling_tf_roberta.py
@@ -65,6 +65,9 @@ class TFRobertaMainLayer(TFBertMainLayer):
         super(TFRobertaMainLayer, self).__init__(config, **kwargs)
         self.embeddings = TFRobertaEmbeddings(config, name='embeddings')
 
+    def get_input_embeddings(self):
+        return self.embeddings
+
 
 class TFRobertaPreTrainedModel(TFPreTrainedModel):
     """ An abstract class to handle weights initialization and
@@ -280,6 +283,9 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel):
         self.roberta = TFRobertaMainLayer(config, name="roberta")
         self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head")
 
+    def get_output_embeddings(self):
+        return self.lm_head.decoder
+
     def call(self, inputs, **kwargs):
         outputs = self.roberta(inputs, **kwargs)
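# Illustrative sketch, not part of the patch: GPT-2 computes LM logits with the
# shared `wte` matrix, so for TFGPT2LMHeadModel both accessors resolve to the same
# layer, while RoBERTa above instead exposes the LM head's `decoder` attribute.
# A quick check with a default config:

from transformers import GPT2Config, TFGPT2LMHeadModel

model = TFGPT2LMHeadModel(GPT2Config())
assert model.get_input_embeddings() is model.transformer.wte
assert model.get_output_embeddings() is model.transformer.wte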
diff --git a/transformers/modeling_tf_transfo_xl.py b/transformers/modeling_tf_transfo_xl.py
index ec37aedd745..8c2a35b3527 100644
--- a/transformers/modeling_tf_transfo_xl.py
+++ b/transformers/modeling_tf_transfo_xl.py
@@ -413,6 +413,9 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
                                           name='r_r_bias')
         super(TFTransfoXLMainLayer, self).build(input_shape)
 
+    def get_input_embeddings(self):
+        return self.word_emb
+
     def _resize_token_embeddings(self, new_num_tokens):
         return self.word_emb
 
diff --git a/transformers/modeling_tf_utils.py b/transformers/modeling_tf_utils.py
index 446fcad131b..e08605d1548 100644
--- a/transformers/modeling_tf_utils.py
+++ b/transformers/modeling_tf_utils.py
@@ -65,6 +65,21 @@ class TFPreTrainedModel(tf.keras.Model):
         # Save config in model
         self.config = config
 
+    def get_input_embeddings(self):
+        """ Get model's input embeddings
+        """
+        base_model = getattr(self, self.base_model_prefix, self)
+        if base_model is not self:
+            return base_model.get_input_embeddings()
+        else:
+            raise NotImplementedError
+
+    def get_output_embeddings(self):
+        """ Get model's output embeddings
+            Return None if the model doesn't have output embeddings
+        """
+        return None  # Overwrite for models with output embeddings
+
     def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
         """ Build a resized Embedding Variable from a provided token Embedding Module.
             Increasing the size will add newly initialized vectors at the end
diff --git a/transformers/modeling_tf_xlm.py b/transformers/modeling_tf_xlm.py
index 496d7d72a8f..20fbdca7321 100644
--- a/transformers/modeling_tf_xlm.py
+++ b/transformers/modeling_tf_xlm.py
@@ -277,6 +277,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
             self.prune_heads({int(layer): list(map(int, heads))})
 
+    def get_input_embeddings(self):
+        return self.embeddings
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
 
@@ -641,6 +644,8 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
         self.transformer = TFXLMMainLayer(config, name='transformer')
         self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name='pred_layer_._proj')
 
+    def get_output_embeddings(self):
+        return self.pred_layer.input_embeddings
 
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
diff --git a/transformers/modeling_tf_xlnet.py b/transformers/modeling_tf_xlnet.py
index bb33e45790d..7ab95e7c9fe 100644
--- a/transformers/modeling_tf_xlnet.py
+++ b/transformers/modeling_tf_xlnet.py
@@ -371,6 +371,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         self.layer = [TFXLNetLayer(config, name='layer_._{}'.format(i)) for i in range(config.n_layer)]
         self.dropout = tf.keras.layers.Dropout(config.dropout)
 
+    def get_input_embeddings(self):
+        return self.word_embedding
+
     def build(self, input_shape):
         initializer = get_initializer(self.initializer_range)
         self.mask_emb = self.add_weight(shape=(1, 1, self.d_model),
@@ -854,6 +857,9 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
         self.transformer = TFXLNetMainLayer(config, name='transformer')
         self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name='lm_loss')
 
+    def get_output_embeddings(self):
+        return self.lm_loss.input_embeddings
+
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_state = transformer_outputs[0]
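# Illustrative sketch, not part of the patch: the TFPreTrainedModel defaults above
# do the routing. get_input_embeddings() follows `base_model_prefix` into the main
# layer, and get_output_embeddings() stays None unless a head model overrides it,
# e.g. for a bare XLNet encoder:

from transformers import XLNetConfig, TFXLNetModel

model = TFXLNetModel(XLNetConfig())
emb = model.get_input_embeddings()            # delegates to transformer.word_embedding
assert model.get_output_embeddings() is None  # no LM head, base-class default applies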
diff --git a/transformers/tests/modeling_tf_common_test.py b/transformers/tests/modeling_tf_common_test.py
index f636c428897..0be5fe8e9cb 100644
--- a/transformers/tests/modeling_tf_common_test.py
+++ b/transformers/tests/modeling_tf_common_test.py
@@ -360,6 +360,16 @@ class TFCommonTestCases:
 
         #     self.assertTrue(models_equal)
 
+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            x = model.get_output_embeddings()
+            assert x is None or isinstance(x, tf.keras.layers.Layer)
+
+
     def test_tie_model_weights(self):
         pass
         # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
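# Illustrative sketch, not part of the patch: the new common test asserts the same
# invariant for every architecture via `all_model_classes`. Checked by hand on one
# bare encoder (output embeddings are None) and one head model (a Layer):

import tensorflow as tf
from transformers import BertConfig, TFBertModel, TFBertForMaskedLM

for model in (TFBertModel(BertConfig()), TFBertForMaskedLM(BertConfig())):
    assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
    out = model.get_output_embeddings()
    assert out is None or isinstance(out, tf.keras.layers.Layer)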