Generalize problem_type to all sequence classification models (#14180)

* Generalize problem_type to all classification models

* Missing import

* Deberta BC and fix tests

* Fix template

* Missing imports

* Revert change to reformer test

* Fix style
Sylvain Gugger 2021-10-29 10:32:56 -04:00 committed by GitHub
parent 4ab6a4a086
commit c28bc80bbb
38 changed files with 474 additions and 191 deletions
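For context, the diffs below make every touched sequence-classification head pick its loss from config.problem_type: MSELoss for "regression", CrossEntropyLoss for "single_label_classification", BCEWithLogitsLoss for "multi_label_classification", with the type inferred from num_labels and the label dtype when it is unset. A minimal usage sketch of the interface being generalized (illustrative only, not part of the diff; the checkpoint name and tensors are placeholders):

    import torch
    from transformers import AutoConfig, AutoModelForSequenceClassification

    # Any sequence-classification head that honours `problem_type` behaves the same way.
    config = AutoConfig.from_pretrained(
        "bert-base-uncased",
        num_labels=3,
        problem_type="multi_label_classification",  # or "regression" / "single_label_classification"
    )
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", config=config)

    # Multi-label targets are float multi-hot vectors; BCEWithLogitsLoss is applied internally.
    inputs = {"input_ids": torch.tensor([[101, 2023, 2003, 102]])}
    labels = torch.tensor([[1.0, 0.0, 1.0]])
    loss = model(**inputs, labels=labels).loss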

@@ -22,7 +22,7 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -1475,14 +1475,26 @@ class BartForSequenceClassification(BartPretrainedModel):
         loss = None
         if labels is not None:
-            if self.config.num_labels == 1:
-                # regression
+            if self.config.problem_type is None:
+                if self.config.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.config.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output

@@ -23,7 +23,7 @@ from typing import Optional, Tuple
 import numpy as np
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -2680,14 +2680,26 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.config.num_labels == 1:
-                # regression
+            if self.config.problem_type is None:
+                if self.config.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.config.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output

@@ -20,7 +20,7 @@ from typing import Tuple
 import numpy as np
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
@@ -690,14 +690,26 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[2:]
             return ((loss,) + output) if loss is not None else output

@@ -19,7 +19,7 @@ from collections.abc import Sequence
 import torch
 from torch import _softmax_backward_data, nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -1194,31 +1194,46 @@ class DebertaForSequenceClassification(DebertaPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # regression task
-                loss_fn = nn.MSELoss()
-                logits = logits.view(-1).to(labels.dtype)
-                loss = loss_fn(logits, labels.view(-1))
-            elif labels.dim() == 1 or labels.size(-1) == 1:
-                label_index = (labels >= 0).nonzero()
-                labels = labels.long()
-                if label_index.size(0) > 0:
-                    labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1)))
-                    labels = torch.gather(labels, 0, label_index.view(-1))
-                    loss_fct = CrossEntropyLoss()
-                    loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
-                else:
-                    loss = torch.tensor(0).to(logits)
-            else:
-                log_softmax = nn.LogSoftmax(-1)
-                loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    # regression task
+                    loss_fn = nn.MSELoss()
+                    logits = logits.view(-1).to(labels.dtype)
+                    loss = loss_fn(logits, labels.view(-1))
+                elif labels.dim() == 1 or labels.size(-1) == 1:
+                    label_index = (labels >= 0).nonzero()
+                    labels = labels.long()
+                    if label_index.size(0) > 0:
+                        labeled_logits = torch.gather(
+                            logits, 0, label_index.expand(label_index.size(0), logits.size(1))
+                        )
+                        labels = torch.gather(labels, 0, label_index.view(-1))
+                        loss_fct = CrossEntropyLoss()
+                        loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
+                    else:
+                        loss = torch.tensor(0).to(logits)
+                else:
+                    log_softmax = nn.LogSoftmax(-1)
+                    loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+            elif self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
-        else:
-            return SequenceClassifierOutput(
-                loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
-            )
+        return SequenceClassifierOutput(
+            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+        )
 @add_start_docstrings(
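As the "Deberta BC" item in the commit message indicates, the hunk above keeps DeBERTa's legacy loss whenever problem_type is left unset, so existing checkpoints and training scripts keep their old behaviour. Setting the field explicitly opts into the standardized losses; a hedged sketch (randomly initialized model, illustrative only):

    from transformers import DebertaConfig, DebertaForSequenceClassification

    config = DebertaConfig(num_labels=4, problem_type="single_label_classification")
    model = DebertaForSequenceClassification(config)  # CrossEntropyLoss is used for integer labels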

@@ -20,7 +20,7 @@ from collections.abc import Sequence
 import numpy as np
 import torch
 from torch import _softmax_backward_data, nn
-from torch.nn import CrossEntropyLoss, LayerNorm
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -1304,31 +1304,46 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # regression task
-                loss_fn = nn.MSELoss()
-                logits = logits.view(-1).to(labels.dtype)
-                loss = loss_fn(logits, labels.view(-1))
-            elif labels.dim() == 1 or labels.size(-1) == 1:
-                label_index = (labels >= 0).nonzero()
-                labels = labels.long()
-                if label_index.size(0) > 0:
-                    labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1)))
-                    labels = torch.gather(labels, 0, label_index.view(-1))
-                    loss_fct = CrossEntropyLoss()
-                    loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
-                else:
-                    loss = torch.tensor(0).to(logits)
-            else:
-                log_softmax = nn.LogSoftmax(-1)
-                loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    # regression task
+                    loss_fn = nn.MSELoss()
+                    logits = logits.view(-1).to(labels.dtype)
+                    loss = loss_fn(logits, labels.view(-1))
+                elif labels.dim() == 1 or labels.size(-1) == 1:
+                    label_index = (labels >= 0).nonzero()
+                    labels = labels.long()
+                    if label_index.size(0) > 0:
+                        labeled_logits = torch.gather(
+                            logits, 0, label_index.expand(label_index.size(0), logits.size(1))
+                        )
+                        labels = torch.gather(labels, 0, label_index.view(-1))
+                        loss_fct = CrossEntropyLoss()
+                        loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
+                    else:
+                        loss = torch.tensor(0).to(logits)
+                else:
+                    log_softmax = nn.LogSoftmax(-1)
+                    loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+            elif self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
-        else:
-            return SequenceClassifierOutput(
-                loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
-            )
+        return SequenceClassifierOutput(
+            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+        )
 @add_start_docstrings(

@@ -23,7 +23,7 @@ import torch
 import torch.utils.checkpoint
 from packaging import version
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...file_utils import is_scipy_available
@@ -927,14 +927,26 @@ class FNetForSequenceClassification(FNetPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

@@ -24,7 +24,7 @@ import torch
 import torch.utils.checkpoint
 from packaging import version
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 if version.parse(torch.__version__) >= version.parse("1.6"):
@@ -1406,14 +1406,26 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output

@@ -21,7 +21,7 @@ from typing import Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -895,14 +895,26 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output

@@ -19,7 +19,7 @@ from typing import Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -931,14 +931,26 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output

@@ -22,7 +22,7 @@ import math
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import gelu
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -1025,14 +1025,26 @@ class IBertForSequenceClassification(IBertPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

@@ -20,7 +20,7 @@ import math
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
@@ -1059,14 +1059,26 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

@@ -20,7 +20,7 @@ import math
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -1076,14 +1076,26 @@ class LayoutLMv2ForSequenceClassification(LayoutLMv2PreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

@@ -23,7 +23,7 @@ from typing import List, Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -2536,9 +2536,26 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
         loss = None
         if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.config.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.config.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output

@@ -21,7 +21,7 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -1475,14 +1475,26 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.config.num_labels == 1:
-                # regression
+            if self.config.problem_type is None:
+                if self.config.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.config.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output

@@ -25,7 +25,7 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -1525,14 +1525,26 @@ class MegatronBertForSequenceClassification(MegatronBertPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

@@ -20,7 +20,7 @@ import math
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN, gelu
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -736,14 +736,26 @@ class MPNetForSequenceClassification(MPNetPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

@@ -24,7 +24,7 @@ from typing import Optional, Tuple
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import gelu_new, silu
 from ...file_utils import (
@@ -823,14 +823,26 @@ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output

@@ -21,7 +21,7 @@ import os
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -1220,14 +1220,26 @@ class RemBertForSequenceClassification(RemBertPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

@@ -23,7 +23,7 @@ import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -1287,14 +1287,26 @@ class RoFormerForSequenceClassification(RoFormerPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output

@@ -24,7 +24,7 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -1532,14 +1532,26 @@ class TapasForSequenceClassification(TapasPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

@@ -22,7 +22,7 @@ from typing import List, Optional, Tuple
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...file_utils import (
     ModelOutput,
@@ -1234,13 +1234,26 @@ class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output

@@ -24,7 +24,7 @@ import torch
 import torch.utils.checkpoint
 from packaging import version
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -1265,14 +1265,26 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1564,7 +1576,7 @@ from typing import Optional, Tuple
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -2981,9 +2993,26 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt
         loss = None
         if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.config.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.config.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output

@@ -234,8 +234,6 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
     fx_ready_model_classes = all_model_classes
     fx_dynamic_ready_model_classes = all_model_classes
-    test_sequence_classification_problem_types = True
     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

@@ -446,7 +446,6 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
     fx_ready_model_classes = all_model_classes
     fx_dynamic_ready_model_classes = all_model_classes
-    test_sequence_classification_problem_types = True
     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

@@ -435,7 +435,6 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
     # head masking & pruning is currently not supported for big bird
     test_head_masking = False
     test_pruning = False
-    test_sequence_classification_problem_types = True
     # torchscript should be possible, but takes prohibitively long to test.
     # Also torchscript is not an important feature to have in the beginning.

@@ -113,7 +113,6 @@ class ModelTesterMixin:
     test_missing_keys = True
     test_model_parallel = False
     is_encoder_decoder = False
-    test_sequence_classification_problem_types = False
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = copy.deepcopy(inputs_dict)
@@ -387,12 +386,13 @@ class ModelTesterMixin:
         if not self.model_tester.is_training:
             return
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.return_dict = True
         for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            config.return_dict = True
             if model_class in get_values(MODEL_MAPPING):
                 continue
             model = model_class(config)
             model.to(torch_device)
             model.train()
@@ -401,14 +401,14 @@ class ModelTesterMixin:
             loss.backward()
     def test_training_gradient_checkpointing(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         if not self.model_tester.is_training:
             return
-        config.use_cache = False
-        config.return_dict = True
         for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            config.use_cache = False
+            config.return_dict = True
             if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
                 continue
             model = model_class(config)
@@ -1842,9 +1842,6 @@ class ModelTesterMixin:
             model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
     def test_problem_types(self):
-        if not self.test_sequence_classification_problem_types:
-            return
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        problem_types = [
@@ -1880,7 +1877,11 @@ class ModelTesterMixin:
                    # See https://github.com/huggingface/transformers/issues/11780
                    with warnings.catch_warnings(record=True) as warning_list:
                        loss = model(**inputs).loss
-                    self.assertListEqual(warning_list, [])
+                    for w in warning_list:
+                        if "Using a target size that is different to the input size" in str(w.message):
+                            raise ValueError(
+                                f"Something is going wrong in the regression problem: intercepted {w.message}"
+                            )
                    loss.backward()
@@ -2184,7 +2185,6 @@ class ModelPushToHubTester(unittest.TestCase):
                f.write(FAKE_MODEL_CODE)
            repo.push_to_hub()
-            print(os.listdir(tmp_dir))
        new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
        for p1, p2 in zip(model.parameters(), new_model.parameters()):
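With the opt-in flag gone, the updated test_problem_types above runs for every model that has a sequence-classification head and checks that none of the three problem types triggers the PyTorch broadcasting warning. A rough, self-contained sketch of the label shapes and losses it exercises (not the repository's exact code):

    import torch

    batch_size, num_labels = 2, 3
    logits = torch.randn(batch_size, num_labels, requires_grad=True)

    cases = {
        "regression": torch.randn(batch_size, num_labels),                                 # MSELoss
        "single_label_classification": torch.randint(num_labels, (batch_size,)),           # CrossEntropyLoss
        "multi_label_classification": torch.randint(2, (batch_size, num_labels)).float(),  # BCEWithLogitsLoss
    }
    for problem_type, labels in cases.items():
        if problem_type == "regression":
            loss = torch.nn.MSELoss()(logits, labels)
        elif problem_type == "single_label_classification":
            loss = torch.nn.CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
        else:
            loss = torch.nn.BCEWithLogitsLoss()(logits, labels)
        loss.backward()  # each loss builds its own graph from `logits`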

@@ -262,7 +262,6 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
     )
     test_pruning = False
     test_head_masking = False
-    test_sequence_classification_problem_types = True
     def setUp(self):
         self.model_tester = ConvBertModelTester(self)

@@ -214,7 +214,6 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
     test_pruning = True
     test_torchscript = True
     test_resize_embeddings = True
-    test_sequence_classification_problem_types = True
     test_resize_position_embeddings = True
     def setUp(self):

@@ -291,7 +291,6 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
     )
     fx_ready_model_classes = all_model_classes
     fx_dynamic_ready_model_classes = all_model_classes
-    test_sequence_classification_problem_types = True
     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

@@ -362,7 +362,6 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
-    test_sequence_classification_problem_types = True
     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

@@ -278,7 +278,6 @@ class LongformerModelTester:
 class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
     test_pruning = False  # pruning is not supported
     test_torchscript = False
-    test_sequence_classification_problem_types = True
     all_model_classes = (
         (

@@ -271,7 +271,6 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
     )
     fx_ready_model_classes = all_model_classes
     fx_dynamic_ready_model_classes = all_model_classes
-    test_sequence_classification_problem_types = True
     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

@@ -143,7 +143,7 @@ class OpenAIGPTModelTester:
         model = OpenAIGPTForSequenceClassification(config)
         model.to(torch_device)
         model.eval()
-        # print(config.num_labels, sequence_labels.size())
+        sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
         result = model(input_ids, token_type_ids=token_type_ids, labels=sequence_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

@@ -795,6 +795,10 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, Generation
             [expected_shape] * len(iter_hidden_states),
         )
+    def test_problem_types(self):
+        # Fails because the sequence length is not a multiple of 4
+        pass
 @require_torch
 @require_sentencepiece

@@ -356,7 +356,6 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
         else ()
     )
     all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
-    test_sequence_classification_problem_types = True
     def setUp(self):
         self.model_tester = RobertaModelTester(self)

@@ -232,7 +232,6 @@ class SqueezeBertModelTest(ModelTesterMixin, unittest.TestCase):
     test_torchscript = True
     test_resize_embeddings = True
     test_head_masking = False
-    test_sequence_classification_problem_types = True
     def setUp(self):
         self.model_tester = SqueezeBertModelTester(self)

@@ -350,7 +350,6 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_generative_model_classes = (
         (XLMWithLMHeadModel,) if is_torch_available() else ()
     )  # TODO (PVP): Check other models whether language generation is also applicable
-    test_sequence_classification_problem_types = True
     # XLM has 2 QA models -> need to manually set the correct labels for one of them here
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

@@ -527,7 +527,6 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
         (XLNetLMHeadModel,) if is_torch_available() else ()
     )  # TODO (PVP): Check other models whether language generation is also applicable
     test_pruning = False
-    test_sequence_classification_problem_types = True
     # XLNet has 2 QA models -> need to manually set the correct labels for one of them here
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):