mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix TF Flaubert and XLM (#9661)
* Fix Flaubert and XLM * Fix Flaubert and XLM * Apply style
This commit is contained in:
parent
11ec74905a
commit
fa876aee2a
@ -214,10 +214,13 @@ class TFFlaubertPreTrainedModel(TFPreTrainedModel):
|
||||
inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
|
||||
attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
|
||||
if self.config.use_lang_emb and self.config.n_langs > 1:
|
||||
langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
|
||||
return {
|
||||
"input_ids": inputs_list,
|
||||
"attention_mask": attns_list,
|
||||
"langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]),
|
||||
}
|
||||
else:
|
||||
langs_list = None
|
||||
return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}
|
||||
return {"input_ids": inputs_list, "attention_mask": attns_list}
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
|
@ -536,10 +536,13 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
|
||||
inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
|
||||
attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
|
||||
if self.config.use_lang_emb and self.config.n_langs > 1:
|
||||
langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
|
||||
return {
|
||||
"input_ids": inputs_list,
|
||||
"attention_mask": attns_list,
|
||||
"langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]),
|
||||
}
|
||||
else:
|
||||
langs_list = None
|
||||
return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}
|
||||
return {"input_ids": inputs_list, "attention_mask": attns_list}
|
||||
|
||||
|
||||
# Remove when XLMWithLMHead computes loss like other LM models
|
||||
@ -1045,10 +1048,16 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
|
||||
Returns:
|
||||
tf.Tensor with dummy inputs
|
||||
"""
|
||||
return {
|
||||
"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
|
||||
"langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
|
||||
}
|
||||
# Sometimes XLM has language embeddings so don't forget to build them as well if needed
|
||||
if self.config.use_lang_emb and self.config.n_langs > 1:
|
||||
return {
|
||||
"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
|
||||
"langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
|
||||
}
|
||||
|
||||
@add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
|
Loading…
Reference in New Issue
Block a user