Add image classifier donut & update loss calculation for all swins (#37224)

* add classifier head to donut

* add to transformers __init__

* add to auto model

* fix typo

* add loss for image classification

* add checkpoint

* remove unneeded import

* reorder import

* format

* consistency

* add test of classifier

* add doc

* try ignore

* update loss for all swin models
AbdelKarim ELJANDOUBI 2025-04-10 15:00:42 +02:00 committed by GitHub
parent 5ae9b2cac0
commit 7ecc5b88c0
9 changed files with 177 additions and 48 deletions

docs/source/en/model_doc/donut.md

@@ -226,3 +226,8 @@ print(answer)
 
 [[autodoc]] DonutSwinModel
     - forward
+
+## DonutSwinForImageClassification
+
+[[autodoc]] transformers.DonutSwinForImageClassification
+    - forward
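
The new doc entry corresponds to a runnable inference path. A minimal sketch, assuming the `eljandoubi/donut-base-encoder` checkpoint named in this diff also ships an image processor config (the COCO cats image URL is illustrative):

import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, DonutSwinForImageClassification

# Illustrative input image; any RGB image works.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("eljandoubi/donut-base-encoder")
model = DonutSwinForImageClassification.from_pretrained("eljandoubi/donut-base-encoder")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# The diff's expected docstring output for this checkpoint is "tabby, tabby cat".
print(model.config.id2label[logits.argmax(-1).item()])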

src/transformers/__init__.py

@@ -2304,6 +2304,7 @@ else:
     )
     _import_structure["models.donut"].extend(
         [
+            "DonutSwinForImageClassification",
             "DonutSwinModel",
             "DonutSwinPreTrainedModel",
         ]
@@ -7457,6 +7458,7 @@ if TYPE_CHECKING:
         DistilBertPreTrainedModel,
     )
     from .models.donut import (
+        DonutSwinForImageClassification,
         DonutSwinModel,
         DonutSwinPreTrainedModel,
     )
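
Both hunks are needed: the `_import_structure` entry registers the symbol with the lazy module that backs the top-level `transformers` package at runtime, while the `TYPE_CHECKING` branch gives static type checkers a real import. A sketch of the observable effect, assuming this branch is installed:

import transformers

# Resolved lazily on attribute access; the modeling file is not imported
# before this line runs.
cls = transformers.DonutSwinForImageClassification
print(cls.__module__)  # transformers.models.donut.modeling_donut_swin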

src/transformers/loss/loss_utils.py

@@ -145,6 +145,7 @@ LOSS_MAPPING = {
     "ForMaskedLM": ForMaskedLMLoss,
     "ForQuestionAnswering": ForQuestionAnsweringLoss,
     "ForSequenceClassification": ForSequenceClassificationLoss,
+    "ForImageClassification": ForSequenceClassificationLoss,
     "ForTokenClassification": ForTokenClassification,
     "ForSegmentation": ForSegmentationLoss,
     "ForObjectDetection": ForObjectDetectionLoss,

src/transformers/models/auto/modeling_auto.py

@@ -707,6 +707,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("dinat", "DinatForImageClassification"),
         ("dinov2", "Dinov2ForImageClassification"),
         ("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"),
+        ("donut-swin", "DonutSwinForImageClassification"),
         (
             "efficientformer",
             (
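
With the `donut-swin` entry registered, the auto API can build the classifier straight from a config. A small sketch; the label count is arbitrary:

from transformers import AutoModelForImageClassification, DonutSwinConfig

config = DonutSwinConfig(num_labels=10)  # 10 labels chosen arbitrarily
model = AutoModelForImageClassification.from_config(config)
print(type(model).__name__)  # DonutSwinForImageClassification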

src/transformers/models/donut/modeling_donut_swin.py

@@ -49,6 +49,10 @@ _CONFIG_FOR_DOC = "DonutSwinConfig"
 _CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base"
 _EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
 
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "eljandoubi/donut-base-encoder"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
 
 @dataclass
 # Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
@@ -121,6 +125,43 @@ class DonutSwinModelOutput(ModelOutput):
     reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
 
 
+@dataclass
+# Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->DonutSwin
+class DonutSwinImageClassifierOutput(ModelOutput):
+    """
+    DonutSwin outputs for image classification.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, hidden_size, height, width)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+            include the spatial dimensions.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
 # Copied from transformers.models.swin.modeling_swin.window_partition
 def window_partition(input_feature, window_size):
     """
@@ -845,7 +886,7 @@ class DonutSwinEncoder(nn.Module):
         )
 
 
-# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin
+# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin,swin->donut
 class DonutSwinPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -853,7 +894,7 @@ class DonutSwinPreTrainedModel(PreTrainedModel):
     """
 
     config_class = DonutSwinConfig
-    base_model_prefix = "swin"
+    base_model_prefix = "donut"
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
     _no_split_modules = ["DonutSwinStage"]
@@ -1015,4 +1056,90 @@ class DonutSwinModel(DonutSwinPreTrainedModel):
         )
 
 
-__all__ = ["DonutSwinModel", "DonutSwinPreTrainedModel"]
+@add_start_docstrings(
+    """
+    DonutSwin Model transformer with an image classification head on top (a linear layer on top of the final hidden
+    state of the [CLS] token) e.g. for ImageNet.
+
+    <Tip>
+
+    Note that it's possible to fine-tune DonutSwin on higher resolution images than the ones it has been trained on, by
+    setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
+    position embeddings to the higher resolution.
+
+    </Tip>
+    """,
+    SWIN_START_DOCSTRING,
+)
+# Copied from transformers.models.swin.modeling_swin.SwinForImageClassification with Swin->DonutSwin,swin->donut
+class DonutSwinForImageClassification(DonutSwinPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.donut = DonutSwinModel(config)
+
+        # Classifier head
+        self.classifier = (
+            nn.Linear(self.donut.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_IMAGE_CLASS_CHECKPOINT,
+        output_type=DonutSwinImageClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+    )
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, DonutSwinImageClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.donut(
+            pixel_values,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+        )
+
+        pooled_output = outputs[1]
+
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return DonutSwinImageClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            reshaped_hidden_states=outputs.reshaped_hidden_states,
+        )
+
+
+__all__ = ["DonutSwinModel", "DonutSwinPreTrainedModel", "DonutSwinForImageClassification"]
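
Passing `labels` to the new `forward` exercises the shared loss path end to end. A training-style sketch with deliberately tiny, made-up config values:

import torch
from transformers import DonutSwinConfig, DonutSwinForImageClassification

# Tiny illustrative config; real checkpoints use image_size=224 and more labels.
config = DonutSwinConfig(image_size=64, num_labels=3)
model = DonutSwinForImageClassification(config)

pixel_values = torch.randn(2, 3, 64, 64)
labels = torch.tensor([0, 2])  # integer labels with num_labels > 1 -> cross-entropy

outputs = model(pixel_values, labels=labels)
print(outputs.loss)          # scalar loss from self.loss_function
print(outputs.logits.shape)  # torch.Size([2, 3])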

src/transformers/models/swin/modeling_swin.py

@@ -23,7 +23,6 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BackboneOutput
@@ -1285,26 +1284,7 @@ class SwinForImageClassification(SwinPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = BCEWithLogitsLoss()
-                loss = loss_fct(logits, labels)
+            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
 
         if not return_dict:
             output = (logits,) + outputs[2:]
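
The removed branching is exactly what the shared `ForSequenceClassificationLoss` reproduces, so the swap should be behavior-preserving. A condensed restatement of the deleted logic; the function name and signature here are illustrative, not the helper's real interface:

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

def classification_loss(logits, labels, config):
    # Infer problem_type once, as the deleted lines did.
    if config.problem_type is None:
        if config.num_labels == 1:
            config.problem_type = "regression"
        elif config.num_labels > 1 and labels.dtype in (torch.long, torch.int):
            config.problem_type = "single_label_classification"
        else:
            config.problem_type = "multi_label_classification"

    if config.problem_type == "regression":
        loss_fct = MSELoss()
        if config.num_labels == 1:
            return loss_fct(logits.squeeze(), labels.squeeze())
        return loss_fct(logits, labels)
    if config.problem_type == "single_label_classification":
        return CrossEntropyLoss()(logits.view(-1, config.num_labels), labels.view(-1))
    return BCEWithLogitsLoss()(logits, labels)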

src/transformers/models/swinv2/modeling_swinv2.py

@@ -23,7 +23,6 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import Tensor, nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BackboneOutput
@@ -1339,26 +1338,7 @@ class Swinv2ForImageClassification(Swinv2PreTrainedModel):
         loss = None
         if labels is not None:
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = BCEWithLogitsLoss()
-                loss = loss_fct(logits, labels)
+            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
 
         if not return_dict:
             output = (logits,) + outputs[2:]

src/transformers/utils/dummy_pt_objects.py

@@ -3829,6 +3829,13 @@ class DistilBertPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class DonutSwinForImageClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class DonutSwinModel(metaclass=DummyObject):
     _backends = ["torch"]
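
The dummy keeps `from transformers import DonutSwinForImageClassification` importable in a torch-less environment while failing loudly on use. A sketch of that behavior (error text paraphrased):

from transformers import DonutSwinForImageClassification
from transformers.utils import is_torch_available

if not is_torch_available():
    try:
        # Without torch, this resolves to the dummy class above.
        DonutSwinForImageClassification()
    except ImportError as err:
        print(err)  # requires_backends points the user at the PyTorch install docs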

tests/models/donut/test_modeling_donut_swin.py

@@ -29,7 +29,7 @@ if is_torch_available():
     import torch
     from torch import nn
 
-    from transformers import DonutSwinModel
+    from transformers import DonutSwinForImageClassification, DonutSwinModel
 
 
 class DonutSwinModelTester:
@@ -129,6 +129,24 @@ class DonutSwinModelTester:
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
 
+    def create_and_check_for_image_classification(self, config, pixel_values, labels):
+        config.num_labels = self.type_sequence_label_size
+        model = DonutSwinForImageClassification(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(pixel_values, labels=labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
+        # test greyscale images
+        config.num_channels = 1
+        model = DonutSwinForImageClassification(config)
+        model.to(torch_device)
+        model.eval()
+
+        pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+        result = model(pixel_values)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -142,8 +160,12 @@ class DonutSwinModelTester:
 @require_torch
 class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
-    all_model_classes = (DonutSwinModel,) if is_torch_available() else ()
-    pipeline_model_mapping = {"image-feature-extraction": DonutSwinModel} if is_torch_available() else {}
+    all_model_classes = (DonutSwinModel, DonutSwinForImageClassification) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {"image-feature-extraction": DonutSwinModel, "image-classification": DonutSwinForImageClassification}
+        if is_torch_available()
+        else {}
+    )
     fx_compatible = True
     test_pruning = False
@@ -167,6 +189,10 @@ class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
 
+    def test_for_image_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
     @unittest.skip(reason="DonutSwin does not use inputs_embeds")
     def test_inputs_embeds(self):
         pass
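
To exercise just the new test in place, unittest can target it by name. A sketch that assumes it runs from the repository root with the test dependencies installed:

import unittest
from tests.models.donut.test_modeling_donut_swin import DonutSwinModelTest

# Build a one-test suite for the classifier check added in this commit.
suite = unittest.TestSuite([DonutSwinModelTest("test_for_image_classification")])
unittest.TextTestRunner(verbosity=2).run(suite)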