mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Add image classifier donut & update loss calculation for all swins (#37224)
* add classifier head to donut * add to transformers __init__ * add to auto model * fix typo * add loss for image classification * add checkpoint * remove no needed import * reoder import * format * consistency * add test of classifier * add doc * try ignore * update loss for all swin models
This commit is contained in:
parent
5ae9b2cac0
commit
7ecc5b88c0
@ -226,3 +226,8 @@ print(answer)
|
||||
|
||||
[[autodoc]] DonutSwinModel
|
||||
- forward
|
||||
|
||||
## DonutSwinForImageClassification
|
||||
|
||||
[[autodoc]] transformers.DonutSwinForImageClassification
|
||||
- forward
|
@ -2304,6 +2304,7 @@ else:
|
||||
)
|
||||
_import_structure["models.donut"].extend(
|
||||
[
|
||||
"DonutSwinForImageClassification",
|
||||
"DonutSwinModel",
|
||||
"DonutSwinPreTrainedModel",
|
||||
]
|
||||
@ -7457,6 +7458,7 @@ if TYPE_CHECKING:
|
||||
DistilBertPreTrainedModel,
|
||||
)
|
||||
from .models.donut import (
|
||||
DonutSwinForImageClassification,
|
||||
DonutSwinModel,
|
||||
DonutSwinPreTrainedModel,
|
||||
)
|
||||
|
@ -145,6 +145,7 @@ LOSS_MAPPING = {
|
||||
"ForMaskedLM": ForMaskedLMLoss,
|
||||
"ForQuestionAnswering": ForQuestionAnsweringLoss,
|
||||
"ForSequenceClassification": ForSequenceClassificationLoss,
|
||||
"ForImageClassification": ForSequenceClassificationLoss,
|
||||
"ForTokenClassification": ForTokenClassification,
|
||||
"ForSegmentation": ForSegmentationLoss,
|
||||
"ForObjectDetection": ForObjectDetectionLoss,
|
||||
|
@ -707,6 +707,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("dinat", "DinatForImageClassification"),
|
||||
("dinov2", "Dinov2ForImageClassification"),
|
||||
("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"),
|
||||
("donut-swin", "DonutSwinForImageClassification"),
|
||||
(
|
||||
"efficientformer",
|
||||
(
|
||||
|
@ -49,6 +49,10 @@ _CONFIG_FOR_DOC = "DonutSwinConfig"
|
||||
_CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base"
|
||||
_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
|
||||
|
||||
# Image classification docstring
|
||||
_IMAGE_CLASS_CHECKPOINT = "eljandoubi/donut-base-encoder"
|
||||
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
|
||||
|
||||
|
||||
@dataclass
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
|
||||
@ -121,6 +125,43 @@ class DonutSwinModelOutput(ModelOutput):
|
||||
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->DonutSwin
|
||||
class DonutSwinImageClassifierOutput(ModelOutput):
|
||||
"""
|
||||
DonutSwin outputs for image classification.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
|
||||
shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
|
||||
shape `(batch_size, hidden_size, height, width)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
|
||||
include the spatial dimensions.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.window_partition
|
||||
def window_partition(input_feature, window_size):
|
||||
"""
|
||||
@ -845,7 +886,7 @@ class DonutSwinEncoder(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin,swin->donut
|
||||
class DonutSwinPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
@ -853,7 +894,7 @@ class DonutSwinPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
|
||||
config_class = DonutSwinConfig
|
||||
base_model_prefix = "swin"
|
||||
base_model_prefix = "donut"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["DonutSwinStage"]
|
||||
@ -1015,4 +1056,90 @@ class DonutSwinModel(DonutSwinPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["DonutSwinModel", "DonutSwinPreTrainedModel"]
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
DonutSwin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
|
||||
the [CLS] token) e.g. for ImageNet.
|
||||
|
||||
<Tip>
|
||||
|
||||
Note that it's possible to fine-tune DonutSwin on higher resolution images than the ones it has been trained on, by
|
||||
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
|
||||
position embeddings to the higher resolution.
|
||||
|
||||
</Tip>
|
||||
""",
|
||||
SWIN_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinForImageClassification with Swin->DonutSwin,swin->donut
|
||||
class DonutSwinForImageClassification(DonutSwinPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.donut = DonutSwinModel(config)
|
||||
|
||||
# Classifier head
|
||||
self.classifier = (
|
||||
nn.Linear(self.donut.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
|
||||
)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
||||
output_type=DonutSwinImageClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, DonutSwinImageClassifierOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.donut(
|
||||
pixel_values,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return DonutSwinImageClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
reshaped_hidden_states=outputs.reshaped_hidden_states,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["DonutSwinModel", "DonutSwinPreTrainedModel", "DonutSwinForImageClassification"]
|
||||
|
@ -23,7 +23,6 @@ from typing import Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BackboneOutput
|
||||
@ -1285,26 +1284,7 @@ class SwinForImageClassification(SwinPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -23,7 +23,6 @@ from typing import Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BackboneOutput
|
||||
@ -1339,26 +1338,7 @@ class Swinv2ForImageClassification(Swinv2PreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -3829,6 +3829,13 @@ class DistilBertPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class DonutSwinForImageClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class DonutSwinModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -29,7 +29,7 @@ if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import DonutSwinModel
|
||||
from transformers import DonutSwinForImageClassification, DonutSwinModel
|
||||
|
||||
|
||||
class DonutSwinModelTester:
|
||||
@ -129,6 +129,24 @@ class DonutSwinModelTester:
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
|
||||
|
||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||
config.num_labels = self.type_sequence_label_size
|
||||
model = DonutSwinForImageClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values, labels=labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||
|
||||
# test greyscale images
|
||||
config.num_channels = 1
|
||||
model = DonutSwinForImageClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||
result = model(pixel_values)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@ -142,8 +160,12 @@ class DonutSwinModelTester:
|
||||
|
||||
@require_torch
|
||||
class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (DonutSwinModel,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"image-feature-extraction": DonutSwinModel} if is_torch_available() else {}
|
||||
all_model_classes = (DonutSwinModel, DonutSwinForImageClassification) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"image-feature-extraction": DonutSwinModel, "image-classification": DonutSwinForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
fx_compatible = True
|
||||
|
||||
test_pruning = False
|
||||
@ -167,6 +189,10 @@ class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_for_image_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="DonutSwin does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
Loading…
Reference in New Issue
Block a user