mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add multi-class, multi-label and regression to transformers (#11012)
* add to bert * review comments * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * self.config.problem_type * fix style * fix * fin * fix * update doc * fix * test * Test more problem types * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix * remove * fix * quality * make fix-copies * remove test Co-authored-by: abhishek thakur <abhishekkrthakur@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
parent
7c622482e8
commit
c40c7e213b
@ -163,6 +163,14 @@ class PretrainedConfig(PushToHubMixin):
|
||||
typically for a classification task.
|
||||
- **task_specific_params** (:obj:`Dict[str, Any]`, `optional`) -- Additional keyword arguments to store for the
|
||||
current task.
|
||||
- **problem_type** (:obj:`str`, `optional`) -- Problem type for :obj:`XxxForSequenceClassification` models. Can
|
||||
be one of (:obj:`"regression"`, :obj:`"single_label_classification"`, :obj:`"multi_label_classification"`).
|
||||
Please note that this parameter is only available in the following models: `AlbertForSequenceClassification`,
|
||||
`BertForSequenceClassification`, `BigBirdForSequenceClassification`, `ConvBertForSequenceClassification`,
|
||||
`DistilBertForSequenceClassification`, `ElectraForSequenceClassification`, `FunnelForSequenceClassification`,
|
||||
`LongformerForSequenceClassification`, `MobileBertForSequenceClassification`,
|
||||
`ReformerForSequenceClassification`, `RobertaForSequenceClassification`,
|
||||
`SqueezeBertForSequenceClassification`, `XLMForSequenceClassification` and `XLNetForSequenceClassification`.
|
||||
|
||||
Parameters linked to the tokenizer
|
||||
|
||||
@ -260,6 +268,15 @@ class PretrainedConfig(PushToHubMixin):
|
||||
# task specific arguments
|
||||
self.task_specific_params = kwargs.pop("task_specific_params", None)
|
||||
|
||||
# regression / multi-label classification
|
||||
self.problem_type = kwargs.pop("problem_type", None)
|
||||
allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification")
|
||||
if self.problem_type is not None and self.problem_type not in allowed_problem_types:
|
||||
raise ValueError(
|
||||
f"The config parameter `problem_type` wasnot understood: received {self.problem_type}"
|
||||
"but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
|
||||
)
|
||||
|
||||
# TPU arguments
|
||||
if kwargs.pop("xla_device", None) is not None:
|
||||
logger.warning(
|
||||
|
@ -21,7 +21,7 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import (
|
||||
@ -970,6 +970,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.config = config
|
||||
|
||||
self.albert = AlbertModel(config)
|
||||
self.dropout = nn.Dropout(config.classifier_dropout_prob)
|
||||
@ -1024,13 +1025,23 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -25,7 +25,7 @@ from typing import Optional, Tuple
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import (
|
||||
@ -1381,7 +1381,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@ -1463,6 +1463,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.config = config
|
||||
|
||||
self.bert = BertModel(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
@ -1517,14 +1518,23 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
@ -25,7 +25,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import (
|
||||
@ -2609,6 +2609,7 @@ class BigBirdForSequenceClassification(BigBirdPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.config = config
|
||||
self.bert = BigBirdModel(config)
|
||||
self.classifier = BigBirdClassificationHead(config)
|
||||
|
||||
@ -2659,13 +2660,23 @@ class BigBirdForSequenceClassification(BigBirdPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -22,7 +22,7 @@ from operator import attrgetter
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN, get_activation
|
||||
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
@ -962,6 +962,7 @@ class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.config = config
|
||||
self.convbert = ConvBertModel(config)
|
||||
self.classifier = ConvBertClassificationHead(config)
|
||||
|
||||
@ -1012,13 +1013,23 @@ class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -24,7 +24,7 @@ import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import gelu
|
||||
from ...file_utils import (
|
||||
@ -579,6 +579,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.config = config
|
||||
|
||||
self.distilbert = DistilBertModel(config)
|
||||
self.pre_classifier = nn.Linear(config.dim, config.dim)
|
||||
@ -631,12 +632,23 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
loss_fct = nn.MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + distilbert_output[1:]
|
||||
|
@ -22,7 +22,7 @@ from typing import Optional, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN, get_activation
|
||||
from ...file_utils import (
|
||||
@ -903,6 +903,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.config = config
|
||||
self.electra = ElectraModel(config)
|
||||
self.classifier = ElectraClassificationHead(config)
|
||||
|
||||
@ -953,13 +954,23 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + discriminator_hidden_states[1:]
|
||||
|
@ -21,7 +21,7 @@ from typing import Optional, Tuple
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ...activations import ACT2FN
|
||||
@ -1240,6 +1240,7 @@ class FunnelForSequenceClassification(FunnelPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.config = config
|
||||
|
||||
self.funnel = FunnelBaseModel(config)
|
||||
self.classifier = FunnelClassificationHead(config, config.num_labels)
|
||||
@ -1287,13 +1288,23 @@ class FunnelForSequenceClassification(FunnelPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -21,7 +21,7 @@ from typing import Optional, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ...activations import ACT2FN, gelu
|
||||
@ -1803,6 +1803,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.config = config
|
||||
|
||||
self.longformer = LongformerModel(config, add_pooling_layer=False)
|
||||
self.classifier = LongformerClassificationHead(config)
|
||||
@ -1861,13 +1862,23 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -29,7 +29,7 @@ from typing import Optional, Tuple
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import (
|
||||
@ -1214,6 +1214,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.config = config
|
||||
|
||||
self.mobilebert = MobileBertModel(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
@ -1268,14 +1269,23 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
@ -26,7 +26,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd.function import Function
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import (
|
||||
@ -366,7 +366,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
past_buckets_states=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
sequence_length = hidden_states.shape[1]
|
||||
batch_size = hidden_states.shape[0]
|
||||
@ -1045,7 +1045,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
past_buckets_states=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
sequence_length = hidden_states.shape[1]
|
||||
batch_size = hidden_states.shape[0]
|
||||
@ -2381,6 +2381,7 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.config = config
|
||||
|
||||
self.reformer = ReformerModel(config)
|
||||
self.classifier = ReformerClassificationHead(config)
|
||||
@ -2434,13 +2435,23 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -20,7 +20,7 @@ import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN, gelu
|
||||
from ...file_utils import (
|
||||
@ -1117,6 +1117,7 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.config = config
|
||||
|
||||
self.roberta = RobertaModel(config, add_pooling_layer=False)
|
||||
self.classifier = RobertaClassificationHead(config)
|
||||
@ -1167,13 +1168,23 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -19,7 +19,7 @@ import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
@ -733,6 +733,7 @@ class SqueezeBertForSequenceClassification(SqueezeBertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.config = config
|
||||
|
||||
self.transformer = SqueezeBertModel(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
@ -787,13 +788,23 @@ class SqueezeBertForSequenceClassification(SqueezeBertPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -24,7 +24,7 @@ from typing import Optional, Tuple
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ...activations import gelu
|
||||
@ -779,6 +779,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.config = config
|
||||
|
||||
self.transformer = XLMModel(config)
|
||||
self.sequence_summary = SequenceSummary(config)
|
||||
@ -836,13 +837,23 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + transformer_outputs[1:]
|
||||
|
@ -22,7 +22,7 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ...activations import ACT2FN
|
||||
@ -1488,6 +1488,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.config = config
|
||||
|
||||
self.transformer = XLNetModel(config)
|
||||
self.sequence_summary = SequenceSummary(config)
|
||||
@ -1551,13 +1552,23 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + transformer_outputs[1:]
|
||||
|
@ -230,6 +230,8 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
@ -439,6 +439,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
@ -433,6 +433,7 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
# head masking & pruning is currently not supported for big bird
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# torchscript should be possible, but takes prohibitively long to test.
|
||||
# Also torchscript is not an important feature to have in the beginning.
|
||||
|
@ -89,6 +89,7 @@ class ModelTesterMixin:
|
||||
test_missing_keys = True
|
||||
test_model_parallel = False
|
||||
is_encoder_decoder = False
|
||||
test_sequence_classification_problem_types = False
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
@ -1238,6 +1239,42 @@ class ModelTesterMixin:
|
||||
model.parallelize()
|
||||
model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
|
||||
|
||||
def test_problem_types(self):
|
||||
if not self.test_sequence_classification_problem_types:
|
||||
return
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
problem_types = [
|
||||
{"title": "multi_label_classification", "num_labels": 2, "dtype": torch.float},
|
||||
{"title": "single_label_classification", "num_labels": 1, "dtype": torch.long},
|
||||
{"title": "regression", "num_labels": 1, "dtype": torch.float},
|
||||
]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
|
||||
continue
|
||||
|
||||
for problem_type in problem_types:
|
||||
with self.subTest(msg=f"Testing {model_class} with {problem_type['title']}"):
|
||||
|
||||
config.problem_type = problem_type["title"]
|
||||
config.num_labels = problem_type["num_labels"]
|
||||
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
|
||||
if problem_type["num_labels"] > 1:
|
||||
inputs["labels"] = inputs["labels"].unsqueeze(1).repeat(1, problem_type["num_labels"])
|
||||
|
||||
inputs["labels"] = inputs["labels"].to(problem_type["dtype"])
|
||||
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
@ -260,6 +260,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
)
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = ConvBertModelTester(self)
|
||||
|
@ -211,6 +211,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_pruning = True
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = True
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = DistilBertModelTester(self)
|
||||
|
@ -287,6 +287,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
@ -360,6 +360,7 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
@ -274,6 +274,7 @@ class LongformerModelTester:
|
||||
class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_pruning = False # pruning is not supported
|
||||
test_torchscript = False
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
|
@ -267,6 +267,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
@ -590,6 +590,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_torchscript = False
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
def prepare_kwargs(self):
|
||||
return {
|
||||
|
@ -351,6 +351,7 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = RobertaModelTester(self)
|
||||
|
@ -231,6 +231,7 @@ class SqueezeBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = True
|
||||
test_head_masking = False
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = SqueezeBertModelTester(self)
|
||||
|
@ -349,6 +349,7 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (
|
||||
(XLMWithLMHeadModel,) if is_torch_available() else ()
|
||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# XLM has 2 QA models -> need to manually set the correct labels for one of them here
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
||||
(XLNetLMHeadModel,) if is_torch_available() else ()
|
||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||
test_pruning = False
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# XLNet has 2 QA models -> need to manually set the correct labels for one of them here
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
Loading…
Reference in New Issue
Block a user