mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[vision] Add problem_type support (#15851)
* Add problem_type to missing models * Fix deit test Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
parent
7ff9d450cd
commit
286fdc6b3c
@ -22,7 +22,7 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import (
|
||||
@ -857,14 +857,26 @@ class BeitForImageClassification(BeitPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
@ -23,7 +23,7 @@ from typing import Optional, Tuple
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import (
|
||||
@ -726,14 +726,26 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
@ -21,7 +21,7 @@ import math
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import (
|
||||
@ -590,14 +590,26 @@ class SegformerForImageClassification(SegformerPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
@ -21,7 +21,7 @@ import math
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import (
|
||||
@ -747,14 +747,26 @@ class ViTForImageClassification(ViTPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
@ -17,6 +17,7 @@
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
from transformers import DeiTConfig
|
||||
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
|
||||
@ -32,6 +33,8 @@ if is_torch_available():
|
||||
from torch import nn
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
DeiTForImageClassification,
|
||||
DeiTForImageClassificationWithTeacher,
|
||||
@ -379,6 +382,57 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
def test_problem_types(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
problem_types = [
|
||||
{"title": "multi_label_classification", "num_labels": 2, "dtype": torch.float},
|
||||
{"title": "single_label_classification", "num_labels": 1, "dtype": torch.long},
|
||||
{"title": "regression", "num_labels": 1, "dtype": torch.float},
|
||||
]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if (
|
||||
model_class
|
||||
not in [
|
||||
*get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
|
||||
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
|
||||
]
|
||||
or model_class.__name__ == "DeiTForImageClassificationWithTeacher"
|
||||
):
|
||||
continue
|
||||
|
||||
for problem_type in problem_types:
|
||||
with self.subTest(msg=f"Testing {model_class} with {problem_type['title']}"):
|
||||
|
||||
config.problem_type = problem_type["title"]
|
||||
config.num_labels = problem_type["num_labels"]
|
||||
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
|
||||
if problem_type["num_labels"] > 1:
|
||||
inputs["labels"] = inputs["labels"].unsqueeze(1).repeat(1, problem_type["num_labels"])
|
||||
|
||||
inputs["labels"] = inputs["labels"].to(problem_type["dtype"])
|
||||
|
||||
# This tests that we do not trigger the warning form PyTorch "Using a target size that is different
|
||||
# to the input size. This will likely lead to incorrect results due to broadcasting. Please ensure
|
||||
# they have the same size." which is a symptom something in wrong for the regression problem.
|
||||
# See https://github.com/huggingface/transformers/issues/11780
|
||||
with warnings.catch_warnings(record=True) as warning_list:
|
||||
loss = model(**inputs).loss
|
||||
for w in warning_list:
|
||||
if "Using a target size that is different to the input size" in str(w.message):
|
||||
raise ValueError(
|
||||
f"Something is going wrong in the regression problem: intercepted {w.message}"
|
||||
)
|
||||
|
||||
loss.backward()
|
||||
|
||||
def test_for_image_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||
|
@ -1873,7 +1873,10 @@ class ModelTesterMixin:
|
||||
]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
|
||||
if model_class not in [
|
||||
*get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
|
||||
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
|
||||
]:
|
||||
continue
|
||||
|
||||
for problem_type in problem_types:
|
||||
|
Loading…
Reference in New Issue
Block a user