Fix TF Flaubert and XLM (#9661)

* Fix Flaubert and XLM

* Fix Flaubert and XLM

* Apply style
Julien Plu 2021-01-19 18:02:57 +01:00 committed by GitHub
parent 11ec74905a
commit fa876aee2a
2 changed files with 22 additions and 10 deletions


@@ -214,10 +214,13 @@ class TFFlaubertPreTrainedModel(TFPreTrainedModel):
         inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
         attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
         if self.config.use_lang_emb and self.config.n_langs > 1:
-            langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
+            return {
+                "input_ids": inputs_list,
+                "attention_mask": attns_list,
+                "langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]),
+            }
         else:
-            langs_list = None
-        return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}
+            return {"input_ids": inputs_list, "attention_mask": attns_list}
 
 
 @add_start_docstrings(

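With this change, dummy_inputs only carries a "langs" entry when the config actually enables language embeddings (use_lang_emb and n_langs > 1); previously the dict always had a "langs" key, set to None for single-language configs. A minimal sketch of the resulting behaviour, assuming transformers with TensorFlow installed; the tiny config values below are illustrative only, not from this commit:

from transformers import FlaubertConfig, TFFlaubertModel

# Multi-language config: dummy inputs include a "langs" tensor.
multi = TFFlaubertModel(
    FlaubertConfig(vocab_size=64, emb_dim=16, n_layers=1, n_heads=2, n_langs=2, use_lang_emb=True)
)
print(sorted(multi.dummy_inputs))   # ['attention_mask', 'input_ids', 'langs']

# Single-language config: no language embeddings, so no "langs" entry.
single = TFFlaubertModel(
    FlaubertConfig(vocab_size=64, emb_dim=16, n_layers=1, n_heads=2, n_langs=1)
)
print(sorted(single.dummy_inputs))  # ['attention_mask', 'input_ids']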

@@ -536,10 +536,13 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
         inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
         attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
         if self.config.use_lang_emb and self.config.n_langs > 1:
-            langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
+            return {
+                "input_ids": inputs_list,
+                "attention_mask": attns_list,
+                "langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]),
+            }
         else:
-            langs_list = None
-        return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}
+            return {"input_ids": inputs_list, "attention_mask": attns_list}
 
 
 # Remove when XLMWithLMHead computes loss like other LM models
@@ -1045,10 +1048,16 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
         Returns:
             tf.Tensor with dummy inputs
         """
-        return {
-            "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
-            "langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
-        }
+        # Sometimes XLM has language embeddings so don't forget to build them as well if needed
+        if self.config.use_lang_emb and self.config.n_langs > 1:
+            return {
+                "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
+                "langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
+            }
+        else:
+            return {
+                "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
+            }
 
     @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
     @add_code_sample_docstrings(
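The same guard now applies to the multiple-choice head: the "langs" dummy tensor is only built for configs with more than one language. A similarly hedged sketch, using an illustrative tiny single-language config rather than anything from the commit:

from transformers import TFXLMForMultipleChoice, XLMConfig

# Tiny single-language config for illustration; real checkpoints are much larger.
config = XLMConfig(vocab_size=64, emb_dim=16, n_layers=1, n_heads=2, n_langs=1)
model = TFXLMForMultipleChoice(config)

# With n_langs == 1 the dummy inputs no longer include a "langs" entry.
print(sorted(model.dummy_inputs))  # ['input_ids']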