Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
[Tests] Add Common Test for Training + Fix a couple of bugs (#8415)
* add training tests
* correct longformer
* fix docs
* fix some tests
* fix some more train tests
* remove ipdb
* fix multiple edge case model training
* fix funnel and prophetnet
* clean gpt models
* undo renaming of albert
This commit is contained in:
parent 52040517b8
commit 9c83b96e62
@@ -81,6 +81,13 @@ AutoModelForMultipleChoice
     :members:


+AutoModelForNextSentencePrediction
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.AutoModelForNextSentencePrediction
+    :members:
+
+
 AutoModelForTokenClassification
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -1801,7 +1801,7 @@ class GeneralizedRCNN(nn.Module):
                 f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                 f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                 f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
-                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
+                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                 f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                 f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
             )
@@ -29,7 +29,7 @@ For further information or requests, please go to [BERTimbau repository](https:/

 ```python
 from transformers import AutoTokenizer  # Or BertTokenizer
-from transformers import AutoModelForPretraining  # Or BertForPreTraining for loading pretraining heads
+from transformers import AutoModelForPreTraining  # Or BertForPreTraining for loading pretraining heads
 from transformers import AutoModel  # or BertModel, for BERT without pretraining heads

 model = AutoModelForPreTraining.from_pretrained('neuralmind/bert-base-portuguese-cased')
@@ -329,6 +329,7 @@ if is_torch_available():
         MODEL_FOR_CAUSAL_LM_MAPPING,
         MODEL_FOR_MASKED_LM_MAPPING,
         MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
+        MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
         MODEL_FOR_PRETRAINING_MAPPING,
         MODEL_FOR_QUESTION_ANSWERING_MAPPING,
         MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
@@ -340,6 +341,7 @@ if is_torch_available():
         AutoModelForCausalLM,
         AutoModelForMaskedLM,
         AutoModelForMultipleChoice,
+        AutoModelForNextSentencePrediction,
         AutoModelForPreTraining,
         AutoModelForQuestionAnswering,
         AutoModelForSeq2SeqLM,
@@ -77,6 +77,7 @@ from .modeling_bart import (
 from .modeling_bert import (
     BertForMaskedLM,
     BertForMultipleChoice,
+    BertForNextSentencePrediction,
     BertForPreTraining,
     BertForQuestionAnswering,
     BertForSequenceClassification,
@@ -128,6 +129,7 @@ from .modeling_fsmt import FSMTForConditionalGeneration, FSMTModel
 from .modeling_funnel import (
     FunnelForMaskedLM,
     FunnelForMultipleChoice,
+    FunnelForPreTraining,
     FunnelForQuestionAnswering,
     FunnelForSequenceClassification,
     FunnelForTokenClassification,
@@ -143,12 +145,13 @@ from .modeling_longformer import (
     LongformerForTokenClassification,
     LongformerModel,
 )
-from .modeling_lxmert import LxmertForPreTraining, LxmertModel
+from .modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel
 from .modeling_marian import MarianMTModel
 from .modeling_mbart import MBartForConditionalGeneration
 from .modeling_mobilebert import (
     MobileBertForMaskedLM,
     MobileBertForMultipleChoice,
+    MobileBertForNextSentencePrediction,
     MobileBertForPreTraining,
     MobileBertForQuestionAnswering,
     MobileBertForSequenceClassification,
@@ -166,6 +169,7 @@ from .modeling_rag import ( # noqa: F401 - need to import all RagModels to be i
 from .modeling_reformer import (
     ReformerForMaskedLM,
     ReformerForQuestionAnswering,
+    ReformerForSequenceClassification,
     ReformerModel,
     ReformerModelWithLMHead,
 )
@@ -285,6 +289,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
         (CTRLConfig, CTRLLMHeadModel),
         (ElectraConfig, ElectraForPreTraining),
         (LxmertConfig, LxmertForPreTraining),
+        (FunnelConfig, FunnelForPreTraining),
     ]
 )

@@ -396,6 +401,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
         (DebertaConfig, DebertaForSequenceClassification),
         (GPT2Config, GPT2ForSequenceClassification),
         (OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
+        (ReformerConfig, ReformerForSequenceClassification),
     ]
 )

@@ -417,6 +423,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
         (ElectraConfig, ElectraForQuestionAnswering),
         (ReformerConfig, ReformerForQuestionAnswering),
         (FunnelConfig, FunnelForQuestionAnswering),
+        (LxmertConfig, LxmertForQuestionAnswering),
     ]
 )

@@ -460,6 +467,13 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
     ]
 )

+MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
+    [
+        (BertConfig, BertForNextSentencePrediction),
+        (MobileBertConfig, MobileBertForNextSentencePrediction),
+    ]
+)
+
 AUTO_MODEL_PRETRAINED_DOCSTRING = r"""

     The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
@@ -1519,3 +1533,103 @@ class AutoModelForMultipleChoice:
                 ", ".join(c.__name__ for c in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
             )
         )
+
+
+class AutoModelForNextSentencePrediction:
+    r"""
+    This is a generic model class that will be instantiated as one of the model classes of the library---with a
+    next sentence prediction head---when created with the
+    :meth:`~transformers.AutoModelForNextSentencePrediction.from_pretrained` class method or the
+    :meth:`~transformers.AutoModelForNextSentencePrediction.from_config` class method.
+
+    This class cannot be instantiated directly using ``__init__()`` (throws an error).
+    """
+
+    def __init__(self):
+        raise EnvironmentError(
+            "AutoModelForNextSentencePrediction is designed to be instantiated "
+            "using the `AutoModelForNextSentencePrediction.from_pretrained(pretrained_model_name_or_path)` or "
+            "`AutoModelForNextSentencePrediction.from_config(config)` methods."
+        )
+
+    @classmethod
+    @replace_list_option_in_docstrings(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, use_model_types=False)
+    def from_config(cls, config):
+        r"""
+        Instantiates one of the model classes of the library---with a next sentence prediction head---from a
+        configuration.
+
+        Note:
+            Loading a model from its configuration file does **not** load the model weights. It only affects the
+            model's configuration. Use :meth:`~transformers.AutoModelForNextSentencePrediction.from_pretrained` to load
+            the model weights.
+
+        Args:
+            config (:class:`~transformers.PretrainedConfig`):
+                The model class to instantiate is selected based on the configuration class:
+
+                List options
+
+        Examples::
+
+            >>> from transformers import AutoConfig, AutoModelForNextSentencePrediction
+            >>> # Download configuration from S3 and cache.
+            >>> config = AutoConfig.from_pretrained('bert-base-uncased')
+            >>> model = AutoModelForNextSentencePrediction.from_config(config)
+        """
+        if type(config) in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
+            return MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)](config)
+
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
+            )
+        )
+
+    @classmethod
+    @replace_list_option_in_docstrings(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING)
+    @add_start_docstrings(
+        "Instantiate one of the model classes of the library---with a next sentence prediction head---from a "
+        "pretrained model.",
+        AUTO_MODEL_PRETRAINED_DOCSTRING,
+    )
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        r"""
+        Examples::
+
+            >>> from transformers import AutoConfig, AutoModelForNextSentencePrediction
+
+            >>> # Download model and configuration from S3 and cache.
+            >>> model = AutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased')
+
+            >>> # Update configuration during loading
+            >>> model = AutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased', output_attentions=True)
+            >>> model.config.output_attentions
+            True
+
+            >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
+            >>> config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
+            >>> model = AutoModelForNextSentencePrediction.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
+        """
+        config = kwargs.pop("config", None)
+        if not isinstance(config, PretrainedConfig):
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
+
+        if type(config) in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
+            return MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
+            )
+        )
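The new auto class dispatches on the configuration type through `MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING`. As a quick orientation, a minimal usage sketch (it restates the docstring examples above and assumes the public `bert-base-uncased` checkpoint is reachable):

```python
from transformers import AutoConfig, AutoModelForNextSentencePrediction

# from_config builds the architecture without downloading weights
config = AutoConfig.from_pretrained("bert-base-uncased")
model = AutoModelForNextSentencePrediction.from_config(config)
print(type(model).__name__)  # -> BertForNextSentencePrediction

# from_pretrained also loads the pretrained NSP head
model = AutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased")
```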
@@ -1228,13 +1228,14 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        next_sentence_label=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs
     ):
         r"""
-        next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
             Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
             (see ``input_ids`` docstring). Indices should be in ``[0, 1]``:

@@ -1255,10 +1256,18 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
             >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
             >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')

-            >>> outputs = model(**encoding, next_sentence_label=torch.LongTensor([1]))
+            >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
             >>> logits = outputs.logits
             >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
         """
+
+        if "next_sentence_label" in kwargs:
+            warnings.warn(
+                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
+                FutureWarning,
+            )
+            labels = kwargs.pop("next_sentence_label")
+
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = self.bert(
@@ -1278,9 +1287,9 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         seq_relationship_scores = self.cls(pooled_output)

         next_sentence_loss = None
-        if next_sentence_label is not None:
+        if labels is not None:
             loss_fct = CrossEntropyLoss()
-            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), next_sentence_label.view(-1))
+            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))

         if not return_dict:
             output = (seq_relationship_scores,) + outputs[2:]
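With the change above, `BertForNextSentencePrediction.forward` takes its targets as `labels`; `next_sentence_label` is still accepted through `**kwargs` but now raises a `FutureWarning`. A minimal sketch of the new call, adapted from the docstring example (assumes the public `bert-base-uncased` checkpoint):

```python
import torch
from transformers import BertForNextSentencePrediction, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")

prompt = "In Italy, pizza served in formal settings is presented unsliced."
next_sentence = "The sky is blue due to the shorter wavelength of blue light."
encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

# 1 = "next_sentence" does not follow "prompt"
outputs = model(**encoding, labels=torch.LongTensor([1]))
print(outputs.loss, outputs.logits)
```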
@@ -1069,7 +1069,7 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
-                loss = loss_fct(pooled_logits.view(-1), labels.view(-1))
+                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
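The `labels.to(self.dtype)` cast only affects the regression branch (`num_labels == 1`): `MSELoss` needs its target in the same dtype as the logits, so integer or mismatched-precision label tensors are converted before the loss. A hedged sketch of that path, assuming the public `gpt2` checkpoint:

```python
import torch
from transformers import GPT2ForSequenceClassification, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2ForSequenceClassification.from_pretrained("gpt2", num_labels=1)  # regression head
model.config.pad_token_id = tokenizer.eos_token_id  # GPT-2 ships without a pad token

inputs = tokenizer("a single regression example", return_tensors="pt")
# an integer target previously broke MSELoss; it is now cast to the model dtype
outputs = model(**inputs, labels=torch.tensor([1]))
outputs.loss.backward()
```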
@@ -1069,7 +1069,7 @@ class LongformerEncoder(nn.Module):

                def create_custom_forward(module):
                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
+                        return module(*inputs, is_global_attn)

                    return custom_forward

@@ -1079,7 +1079,6 @@ class LongformerEncoder(nn.Module):
                    attention_mask,
                    is_index_masked,
                    is_index_global_attn,
-                    is_global_attn,
                )
            else:
                layer_outputs = layer_module(
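Context for the Longformer fix: when gradient checkpointing is active, each layer is invoked through a `custom_forward` closure, and the non-tensor flag has to be bound inside that closure with the right value. The bug was that `output_attentions` was forwarded where the layer expected `is_global_attn`, while `is_global_attn` was additionally passed to `checkpoint()` as a stray extra argument. A generic sketch of the closure pattern (toy module, not the actual Longformer layer):

```python
import torch
from torch.utils.checkpoint import checkpoint


class FlaggedLayer(torch.nn.Module):
    """Toy layer whose forward takes a tensor plus a non-tensor flag."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, hidden_states, double_output):
        out = self.linear(hidden_states)
        return out * 2 if double_output else out


def create_custom_forward(module, flag):
    def custom_forward(*tensor_inputs):
        # the flag is bound here, so checkpoint() only has to re-play tensor inputs
        return module(*tensor_inputs, flag)

    return custom_forward


layer = FlaggedLayer()
hidden_states = torch.rand(2, 4, requires_grad=True)
output = checkpoint(create_custom_forward(layer, True), hidden_states)
output.sum().backward()
```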
@@ -17,6 +17,7 @@

 import math
 import os
+import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple

@@ -1154,16 +1155,17 @@ class LxmertForPreTraining(LxmertPreTrainedModel):
         visual_attention_mask=None,
         token_type_ids=None,
         inputs_embeds=None,
-        masked_lm_labels=None,
+        labels=None,
         obj_labels=None,
         matched_label=None,
         ans=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs,
     ):
         r"""
-        masked_lm_labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
+        labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
             Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
             config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
             (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
@@ -1183,6 +1185,15 @@ class LxmertForPreTraining(LxmertPreTrainedModel):
         Returns:
         """
+
+        if "masked_lm_labels" in kwargs:
+            warnings.warn(
+                "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
+                FutureWarning,
+            )
+            labels = kwargs.pop("masked_lm_labels")
+
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         device = input_ids.device if input_ids is not None else inputs_embeds.device
         lxmert_output = self.lxmert(
             input_ids=input_ids,
@@ -1210,13 +1221,13 @@ class LxmertForPreTraining(LxmertPreTrainedModel):

         total_loss = (
             None
-            if (masked_lm_labels is None and matched_label is None and obj_labels is None and ans is None)
+            if (labels is None and matched_label is None and obj_labels is None and ans is None)
             else torch.tensor(0.0, device=device)
         )
-        if masked_lm_labels is not None and self.task_mask_lm:
+        if labels is not None and self.task_mask_lm:
             masked_lm_loss = self.loss_fcts["ce"](
                 lang_prediction_scores.view(-1, self.config.vocab_size),
-                masked_lm_labels.view(-1),
+                labels.view(-1),
             )
             total_loss += masked_lm_loss
         if matched_label is not None and self.task_matched:
@@ -1391,6 +1402,7 @@ class LxmertForQuestionAnswering(LxmertPreTrainedModel):

         Returns:
         """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         lxmert_output = self.lxmert(
             input_ids=input_ids,
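After this rename, `LxmertForPreTraining` takes its masked-LM targets as `labels`, with `masked_lm_labels` still accepted via `**kwargs` behind a `FutureWarning`. A hedged sketch of a forward/backward pass with a tiny randomly initialized config and dummy visual features (all sizes below are illustrative, not taken from the test suite):

```python
import torch
from transformers import LxmertConfig, LxmertForPreTraining

config = LxmertConfig(
    vocab_size=500, hidden_size=36, num_attention_heads=6,
    l_layers=1, x_layers=1, r_layers=1,
)
model = LxmertForPreTraining(config)
model.train()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
visual_feats = torch.rand(1, 10, config.visual_feat_dim)  # dummy region features
visual_pos = torch.rand(1, 10, config.visual_pos_dim)     # dummy box coordinates
labels = torch.randint(0, config.vocab_size, (1, 8))      # was `masked_lm_labels`

outputs = model(
    input_ids=input_ids,
    visual_feats=visual_feats,
    visual_pos=visual_pos,
    labels=labels,
    return_dict=True,
)
outputs.loss.backward()
```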
@@ -1194,13 +1194,14 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        next_sentence_label=None,
+        labels=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs,
     ):
         r"""
-        next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
             Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
             (see ``input_ids`` docstring) Indices should be in ``[0, 1]``.

@@ -1221,10 +1222,18 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
             >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
             >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')

-            >>> outputs = model(**encoding, next_sentence_label=torch.LongTensor([1]))
+            >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
             >>> loss = outputs.loss
             >>> logits = outputs.logits
         """
+
+        if "next_sentence_label" in kwargs:
+            warnings.warn(
+                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
+                FutureWarning,
+            )
+            labels = kwargs.pop("next_sentence_label")
+
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = self.mobilebert(
@@ -1243,9 +1252,9 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
         seq_relationship_score = self.cls(pooled_output)

         next_sentence_loss = None
-        if next_sentence_label is not None:
+        if labels is not None:
             loss_fct = CrossEntropyLoss()
-            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
+            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), labels.view(-1))

         if not return_dict:
             output = (seq_relationship_score,) + outputs[2:]
@@ -824,7 +824,7 @@ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
-                loss = loss_fct(pooled_logits.view(-1), labels.view(-1))
+                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
@@ -221,7 +221,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
            f"Some weights of the PyTorch model were not used when "
            f"initializing the TF 2.0 model {tf_model.__class__.__name__}: {unexpected_keys}\n"
            f"- This IS expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model trained on another task "
-            f"or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPretraining model).\n"
+            f"or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n"
            f"- This IS NOT expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model that you expect "
            f"to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model)."
        )
@@ -375,7 +375,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
            f"Some weights of the TF 2.0 model were not used when "
            f"initializing the PyTorch model {pt_model.__class__.__name__}: {unexpected_keys}\n"
            f"- This IS expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model trained on another task "
-            f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a TFBertForPretraining model).\n"
+            f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a TFBertForPreTraining model).\n"
            f"- This IS NOT expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model that you expect "
            f"to be exactly identical (e.g. initializing a BertForSequenceClassification model from a TFBertForSequenceClassification model)."
        )
@@ -730,7 +730,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
-                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
+                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
@@ -1047,7 +1047,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
-                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
+                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
@@ -256,6 +256,9 @@ MODEL_FOR_MASKED_LM_MAPPING = None
 MODEL_FOR_MULTIPLE_CHOICE_MAPPING = None


+MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = None
+
+
 MODEL_FOR_PRETRAINING_MAPPING = None


@@ -313,6 +316,15 @@ class AutoModelForMultipleChoice:
         requires_pytorch(self)


+class AutoModelForNextSentencePrediction:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
 class AutoModelForPreTraining:
     def __init__(self, *args, **kwargs):
         requires_pytorch(self)
@@ -24,7 +24,10 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention


 if is_torch_available():
+    import torch
+
     from transformers import (
+        MODEL_FOR_PRETRAINING_MAPPING,
         AlbertConfig,
         AlbertForMaskedLM,
         AlbertForMultipleChoice,
@@ -227,6 +230,20 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
         else ()
     )

+    # special case for ForPreTraining model
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+        if return_labels:
+            if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
+                inputs_dict["labels"] = torch.zeros(
+                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
+                )
+                inputs_dict["sentence_order_label"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+        return inputs_dict
+
     def setUp(self):
         self.model_tester = AlbertModelTester(self)
         self.config_tester = ConfigTester(self, config_class=AlbertConfig, hidden_size=37)
@@ -25,7 +25,10 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, r


 if is_torch_available():
+    import torch
+
     from transformers import (
+        MODEL_FOR_PRETRAINING_MAPPING,
         BertConfig,
         BertForMaskedLM,
         BertForMultipleChoice,
@@ -268,7 +271,7 @@ class BertModelTester:
             input_ids,
             attention_mask=input_mask,
             token_type_ids=token_type_ids,
-            next_sentence_label=sequence_labels,
+            labels=sequence_labels,
         )
         self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))

@@ -377,6 +380,20 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     )
     all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()

+    # special case for ForPreTraining model
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+        if return_labels:
+            if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
+                inputs_dict["labels"] = torch.zeros(
+                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
+                )
+                inputs_dict["next_sentence_label"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+        return inputs_dict
+
     def setUp(self):
         self.model_tester = BertModelTester(self)
         self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)
@@ -35,10 +35,12 @@ if is_torch_available():
         MODEL_FOR_CAUSAL_LM_MAPPING,
         MODEL_FOR_MASKED_LM_MAPPING,
         MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
+        MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
         MODEL_FOR_QUESTION_ANSWERING_MAPPING,
         MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
         MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
         MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+        MODEL_MAPPING,
         AdaptiveEmbedding,
         BertConfig,
         BertModel,
@@ -88,7 +90,10 @@ class ModelTesterMixin:
                 inputs_dict["end_positions"] = torch.zeros(
                     self.model_tester.batch_size, dtype=torch.long, device=torch_device
                 )
-            elif model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
+            elif model_class in [
+                *MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(),
+                *MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(),
+            ]:
                 inputs_dict["labels"] = torch.zeros(
                     self.model_tester.batch_size, dtype=torch.long, device=torch_device
                 )
@@ -204,6 +209,41 @@ class ModelTesterMixin:
             expected_arg_names = ["input_ids"]
             self.assertListEqual(arg_names[:1], expected_arg_names)

+    def test_training(self):
+        if not self.model_tester.is_training:
+            return
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.return_dict = True
+
+        for model_class in self.all_model_classes:
+            if model_class in MODEL_MAPPING.values():
+                continue
+            model = model_class(config)
+            model.to(torch_device)
+            model.train()
+            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            loss = model(**inputs).loss
+            loss.backward()
+
+    def test_training_gradient_checkpointing(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if not self.model_tester.is_training or not hasattr(config, "gradient_checkpointing"):
+            return
+
+        config.gradient_checkpointing = True
+        config.return_dict = True
+
+        for model_class in self.all_model_classes:
+            if model_class in MODEL_MAPPING.values():
+                continue
+            model = model_class(config)
+            model.to(torch_device)
+            model.train()
+            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            loss = model(**inputs).loss
+            loss.backward()
+
     def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True
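The two new common tests run a full forward/backward pass for every head class a model test registers (base models listed in `MODEL_MAPPING` are skipped), once normally and once with `gradient_checkpointing` when the config supports it. Outside the test suite, the equivalent single-model check looks roughly like this hedged sketch with a tiny randomly initialized BERT:

```python
import torch
from transformers import BertConfig, BertForPreTraining

config = BertConfig(
    vocab_size=99, hidden_size=32, num_hidden_layers=2, num_attention_heads=4,
    intermediate_size=37, gradient_checkpointing=True, return_dict=True,
)
model = BertForPreTraining(config)
model.train()

batch_size, seq_length = 2, 7
inputs = {
    "input_ids": torch.randint(0, config.vocab_size, (batch_size, seq_length)),
    "labels": torch.zeros((batch_size, seq_length), dtype=torch.long),
    "next_sentence_label": torch.zeros(batch_size, dtype=torch.long),
}
loss = model(**inputs).loss
loss.backward()
```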
@@ -38,7 +38,7 @@ class DPRModelTester:
         parent,
         batch_size=13,
         seq_length=7,
-        is_training=True,
+        is_training=False,
         use_input_mask=True,
         use_token_type_ids=True,
         use_labels=True,
@@ -24,7 +24,10 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention


 if is_torch_available():
+    import torch
+
     from transformers import (
+        MODEL_FOR_PRETRAINING_MAPPING,
         ElectraConfig,
         ElectraForMaskedLM,
         ElectraForMultipleChoice,
@@ -285,6 +288,17 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
         else ()
     )

+    # special case for ForPreTraining model
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+        if return_labels:
+            if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
+                inputs_dict["labels"] = torch.zeros(
+                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
+                )
+        return inputs_dict
+
     def setUp(self):
         self.model_tester = ElectraModelTester(self)
         self.config_tester = ConfigTester(self, config_class=ElectraConfig, hidden_size=37)
@@ -24,6 +24,8 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention


 if is_torch_available():
+    import torch
+
     from transformers import (
         FlaubertConfig,
         FlaubertForMultipleChoice,
@@ -343,6 +345,21 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
         else ()
     )

+    # Flaubert has 2 QA models -> need to manually set the correct labels for one of them here
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+        if return_labels:
+            if model_class.__name__ == "FlaubertForQuestionAnswering":
+                inputs_dict["start_positions"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+                inputs_dict["end_positions"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+
+        return inputs_dict
+
     def setUp(self):
         self.model_tester = FlaubertModelTester(self)
         self.config_tester = ConfigTester(self, config_class=FlaubertConfig, emb_dim=37)
@@ -27,6 +27,7 @@ if is_torch_available():
     import torch

     from transformers import (
+        MODEL_FOR_PRETRAINING_MAPPING,
         FunnelBaseModel,
         FunnelConfig,
         FunnelForMaskedLM,
@@ -360,6 +361,17 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase):
         else ()
     )

+    # special case for ForPreTraining model
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+        if return_labels:
+            if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
+                inputs_dict["labels"] = torch.zeros(
+                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
+                )
+        return inputs_dict
+
     def setUp(self):
         self.model_tester = FunnelModelTester(self)
         self.config_tester = ConfigTester(self, config_class=FunnelConfig)
@@ -415,6 +427,21 @@ class FunnelBaseModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)

+    # overwrite from test_modeling_common
+    def test_training(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.return_dict = True
+
+        for model_class in self.all_model_classes:
+            if model_class.__name__ == "FunnelBaseModel":
+                continue
+            model = model_class(config)
+            model.to(torch_device)
+            model.train()
+            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            loss = model(**inputs).loss
+            loss.backward()
+

 @require_torch
 @require_sentencepiece
@@ -388,6 +388,29 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
     test_missing_keys = False

+    # special case for DoubleHeads model
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+        if return_labels:
+            if model_class.__name__ == "GPT2DoubleHeadsModel":
+                inputs_dict["labels"] = torch.zeros(
+                    (self.model_tester.batch_size, self.model_tester.num_choices, self.model_tester.seq_length),
+                    dtype=torch.long,
+                    device=torch_device,
+                )
+                inputs_dict["input_ids"] = inputs_dict["labels"]
+                inputs_dict["token_type_ids"] = inputs_dict["labels"]
+                inputs_dict["mc_token_ids"] = torch.zeros(
+                    (self.model_tester.batch_size, self.model_tester.num_choices),
+                    dtype=torch.long,
+                    device=torch_device,
+                )
+                inputs_dict["mc_labels"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+        return inputs_dict
+
     def setUp(self):
         self.model_tester = GPT2ModelTester(self)
         self.config_tester = ConfigTester(self, config_class=GPT2Config, n_embd=37)
@@ -14,6 +14,7 @@
 # limitations under the License.


+import copy
 import unittest

 from transformers import is_torch_available
@@ -26,7 +27,14 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor
 if is_torch_available():
     import torch

-    from transformers import LxmertConfig, LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel
+    from transformers import (
+        MODEL_FOR_PRETRAINING_MAPPING,
+        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+        LxmertConfig,
+        LxmertForPreTraining,
+        LxmertForQuestionAnswering,
+        LxmertModel,
+    )
     from transformers.modeling_lxmert import LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST

@@ -533,6 +541,22 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
     test_pruning = False
     test_torchscript = False

+    # overwrite function because qa models takes different input label shape
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = copy.deepcopy(inputs_dict)
+
+        if return_labels:
+            if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
+                inputs_dict["labels"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+            elif model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
+                # special case for models like BERT that use multi-loss training for PreTraining
+                inputs_dict["labels"] = torch.zeros(
+                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
+                )
+        return inputs_dict
+
     def setUp(self):
         self.model_tester = LxmertModelTester(self)
         self.config_tester = ConfigTester(self, config_class=LxmertConfig, hidden_size=37)
@@ -27,6 +27,7 @@ if is_torch_available():
     import torch

     from transformers import (
+        MODEL_FOR_PRETRAINING_MAPPING,
         MobileBertConfig,
         MobileBertForMaskedLM,
         MobileBertForMultipleChoice,
@@ -220,7 +221,7 @@ class MobileBertModelTester:
             input_ids,
             attention_mask=input_mask,
             token_type_ids=token_type_ids,
-            next_sentence_label=sequence_labels,
+            labels=sequence_labels,
         )
         self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))

@@ -327,6 +328,20 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
         else ()
     )

+    # special case for ForPreTraining model
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+        if return_labels:
+            if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
+                inputs_dict["labels"] = torch.zeros(
+                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
+                )
+                inputs_dict["next_sentence_label"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+        return inputs_dict
+
     def setUp(self):
         self.model_tester = MobileBertModelTester(self)
         self.config_tester = ConfigTester(self, config_class=MobileBertConfig, hidden_size=37)
@@ -182,6 +182,29 @@ class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
         (OpenAIGPTLMHeadModel,) if is_torch_available() else ()
     )  # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly

+    # special case for DoubleHeads model
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+        if return_labels:
+            if model_class.__name__ == "OpenAIGPTDoubleHeadsModel":
+                inputs_dict["labels"] = torch.zeros(
+                    (self.model_tester.batch_size, self.model_tester.num_choices, self.model_tester.seq_length),
+                    dtype=torch.long,
+                    device=torch_device,
+                )
+                inputs_dict["input_ids"] = inputs_dict["labels"]
+                inputs_dict["token_type_ids"] = inputs_dict["labels"]
+                inputs_dict["mc_token_ids"] = torch.zeros(
+                    (self.model_tester.batch_size, self.model_tester.num_choices),
+                    dtype=torch.long,
+                    device=torch_device,
+                )
+                inputs_dict["mc_labels"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+        return inputs_dict
+
     def setUp(self):
         self.model_tester = OpenAIGPTModelTester(self)
         self.config_tester = ConfigTester(self, config_class=OpenAIGPTConfig, n_embd=37)
@@ -1038,7 +1038,7 @@ class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMix
     is_encoder_decoder = False

     def setUp(self):
-        self.model_tester = ProphetNetStandaloneDecoderModelTester(self)
+        self.model_tester = ProphetNetStandaloneDecoderModelTester(self, is_training=False)
         self.config_tester = ConfigTester(self, config_class=ProphetNetConfig)

     def test_config(self):
@@ -1063,7 +1063,7 @@ class ProphetNetStandaloneEncoderModelTest(ModelTesterMixin, unittest.TestCase):
     is_encoder_decoder = False

     def setUp(self):
-        self.model_tester = ProphetNetStandaloneEncoderModelTester(self)
+        self.model_tester = ProphetNetStandaloneEncoderModelTester(self, is_training=False)
         self.config_tester = ConfigTester(self, config_class=ProphetNetConfig)

     def test_config(self):
@@ -42,7 +42,7 @@ class TransfoXLModelTester:
         self.mem_len = 30
         self.key_length = self.seq_length + self.mem_len
         self.clamp_len = 15
-        self.is_training = True
+        self.is_training = False
         self.use_labels = True
         self.vocab_size = 99
         self.cutoffs = [10, 50, 80]
@@ -351,6 +351,21 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         (XLMWithLMHeadModel,) if is_torch_available() else ()
     )  # TODO (PVP): Check other models whether language generation is also applicable

+    # XLM has 2 QA models -> need to manually set the correct labels for one of them here
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+        if return_labels:
+            if model_class.__name__ == "XLMForQuestionAnswering":
+                inputs_dict["start_positions"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+                inputs_dict["end_positions"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+
+        return inputs_dict
+
     def setUp(self):
         self.model_tester = XLMModelTester(self)
         self.config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37)
@@ -499,6 +499,21 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
     )  # TODO (PVP): Check other models whether language generation is also applicable
     test_pruning = False

+    # XLNet has 2 QA models -> need to manually set the correct labels for one of them here
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+        if return_labels:
+            if model_class.__name__ == "XLNetForQuestionAnswering":
+                inputs_dict["start_positions"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+                inputs_dict["end_positions"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+
+        return inputs_dict
+
     def setUp(self):
         self.model_tester = XLNetModelTester(self)
         self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)