diff --git a/docs/source/en/model_doc/donut.md b/docs/source/en/model_doc/donut.md
index 1bc1a3bcfd0..6e3bd3c51ea 100644
--- a/docs/source/en/model_doc/donut.md
+++ b/docs/source/en/model_doc/donut.md
@@ -226,3 +226,8 @@ print(answer)
 
 [[autodoc]] DonutSwinModel
     - forward
+
+## DonutSwinForImageClassification
+
+[[autodoc]] transformers.DonutSwinForImageClassification
+    - forward
\ No newline at end of file
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 35dad4aacbe..71bb1e2bd84 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -2304,6 +2304,7 @@ else:
     )
     _import_structure["models.donut"].extend(
         [
+            "DonutSwinForImageClassification",
             "DonutSwinModel",
             "DonutSwinPreTrainedModel",
         ]
     )
@@ -7457,6 +7458,7 @@ if TYPE_CHECKING:
         DistilBertPreTrainedModel,
     )
     from .models.donut import (
+        DonutSwinForImageClassification,
         DonutSwinModel,
         DonutSwinPreTrainedModel,
     )
diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py
index 0e052aed6a9..98ecbe76520 100644
--- a/src/transformers/loss/loss_utils.py
+++ b/src/transformers/loss/loss_utils.py
@@ -145,6 +145,7 @@ LOSS_MAPPING = {
     "ForMaskedLM": ForMaskedLMLoss,
     "ForQuestionAnswering": ForQuestionAnsweringLoss,
     "ForSequenceClassification": ForSequenceClassificationLoss,
+    "ForImageClassification": ForSequenceClassificationLoss,
    "ForTokenClassification": ForTokenClassification,
    "ForSegmentation": ForSegmentationLoss,
    "ForObjectDetection": ForObjectDetectionLoss,
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 06f57246b88..84df5dc363a 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -707,6 +707,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("dinat", "DinatForImageClassification"),
         ("dinov2", "Dinov2ForImageClassification"),
         ("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"),
+        ("donut-swin", "DonutSwinForImageClassification"),
         (
             "efficientformer",
             (
diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py
index 0d44069fc8b..b06be3bcf61 100644
--- a/src/transformers/models/donut/modeling_donut_swin.py
+++ b/src/transformers/models/donut/modeling_donut_swin.py
@@ -49,6 +49,10 @@ _CONFIG_FOR_DOC = "DonutSwinConfig"
 _CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base"
 _EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
 
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "eljandoubi/donut-base-encoder"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
 
 @dataclass
 # Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
@@ -121,6 +125,43 @@ class DonutSwinModelOutput(ModelOutput):
     reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
 
 
+@dataclass
+# Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->DonutSwin
+class DonutSwinImageClassifierOutput(ModelOutput):
+    """
+    DonutSwin outputs for image classification.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, hidden_size, height, width)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+            include the spatial dimensions.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
 # Copied from transformers.models.swin.modeling_swin.window_partition
 def window_partition(input_feature, window_size):
     """
@@ -845,7 +886,7 @@ class DonutSwinEncoder(nn.Module):
         )
 
 
-# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin
+# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin,swin->donut
 class DonutSwinPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -853,7 +894,7 @@ class DonutSwinPreTrainedModel(PreTrainedModel):
     """
 
     config_class = DonutSwinConfig
-    base_model_prefix = "swin"
+    base_model_prefix = "donut"
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
     _no_split_modules = ["DonutSwinStage"]
@@ -1015,4 +1056,90 @@ class DonutSwinModel(DonutSwinPreTrainedModel):
         )
 
 
-__all__ = ["DonutSwinModel", "DonutSwinPreTrainedModel"]
+@add_start_docstrings(
+    """
+    DonutSwin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
+    the [CLS] token) e.g. for ImageNet.
+
+    <Tip>
+
+    Note that it's possible to fine-tune DonutSwin on higher resolution images than the ones it has been trained on, by
+    setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
+    position embeddings to the higher resolution.
+
+    </Tip>
+    """,
+    SWIN_START_DOCSTRING,
+)
+# Copied from transformers.models.swin.modeling_swin.SwinForImageClassification with Swin->DonutSwin,swin->donut
+class DonutSwinForImageClassification(DonutSwinPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.donut = DonutSwinModel(config)
+
+        # Classifier head
+        self.classifier = (
+            nn.Linear(self.donut.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_IMAGE_CLASS_CHECKPOINT,
+        output_type=DonutSwinImageClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+    )
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, DonutSwinImageClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.donut(
+            pixel_values,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+        )
+
+        pooled_output = outputs[1]
+
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return DonutSwinImageClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            reshaped_hidden_states=outputs.reshaped_hidden_states,
+        )
+
+
+__all__ = ["DonutSwinModel", "DonutSwinPreTrainedModel", "DonutSwinForImageClassification"]
diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index 5de428831e3..e155874d8f0 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -23,7 +23,6 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BackboneOutput
@@ -1285,26 +1284,7 @@ class SwinForImageClassification(SwinPreTrainedModel):
 
         loss = None
         if labels is not None:
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = BCEWithLogitsLoss()
-                loss = loss_fct(logits, labels)
+            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
 
         if not return_dict:
             output = (logits,) + outputs[2:]
diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py
index 4a5d1bd988c..1e07af70159 100644
--- a/src/transformers/models/swinv2/modeling_swinv2.py
+++ b/src/transformers/models/swinv2/modeling_swinv2.py
@@ -23,7 +23,6 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import Tensor, nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BackboneOutput
@@ -1339,26 +1338,7 @@ class Swinv2ForImageClassification(Swinv2PreTrainedModel):
 
         loss = None
         if labels is not None:
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = BCEWithLogitsLoss()
-                loss = loss_fct(logits, labels)
+            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
 
         if not return_dict:
             output = (logits,) + outputs[2:]
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index ad9f373c1bc..b0959e1c8db 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -3829,6 +3829,13 @@ class DistilBertPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class DonutSwinForImageClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class DonutSwinModel(metaclass=DummyObject):
     _backends = ["torch"]
 
diff --git a/tests/models/donut/test_modeling_donut_swin.py b/tests/models/donut/test_modeling_donut_swin.py
index cb7dc99302c..078331389b8 100644
--- a/tests/models/donut/test_modeling_donut_swin.py
+++ b/tests/models/donut/test_modeling_donut_swin.py
@@ -29,7 +29,7 @@ if is_torch_available():
     import torch
     from torch import nn
 
-    from transformers import DonutSwinModel
+    from transformers import DonutSwinForImageClassification, DonutSwinModel
 
 
 class DonutSwinModelTester:
@@ -129,6 +129,24 @@ class DonutSwinModelTester:
 
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
 
+    def create_and_check_for_image_classification(self, config, pixel_values, labels):
+        config.num_labels = self.type_sequence_label_size
+        model = DonutSwinForImageClassification(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(pixel_values, labels=labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
+        # test greyscale images
+        config.num_channels = 1
+        model = DonutSwinForImageClassification(config)
+        model.to(torch_device)
+        model.eval()
+
+        pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+        result = model(pixel_values)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -142,8 +160,12 @@ class DonutSwinModelTester:
 
 @require_torch
 class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
-    all_model_classes = (DonutSwinModel,) if is_torch_available() else ()
-    pipeline_model_mapping = {"image-feature-extraction": DonutSwinModel} if is_torch_available() else {}
+    all_model_classes = (DonutSwinModel, DonutSwinForImageClassification) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {"image-feature-extraction": DonutSwinModel, "image-classification": DonutSwinForImageClassification}
+        if is_torch_available()
+        else {}
+    )
     fx_compatible = True
 
     test_pruning = False
@@ -167,6 +189,10 @@ class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
 
+    def test_for_image_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
     @unittest.skip(reason="DonutSwin does not use inputs_embeds")
     def test_inputs_embeds(self):
         pass
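
For reference, the snippet below is a minimal usage sketch of the new head, mirroring the docstring constants this patch adds to modeling_donut_swin.py (_IMAGE_CLASS_CHECKPOINT and _IMAGE_CLASS_EXPECTED_OUTPUT). It is not part of the patch: the cats test image URL is the one commonly used in transformers doc examples, and it assumes the eljandoubi/donut-base-encoder repo ships an image processor config and an ImageNet-style id2label mapping.

import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, DonutSwinForImageClassification

# Standard two-cats test image used throughout the transformers docs
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Assumption: the checkpoint named in _IMAGE_CLASS_CHECKPOINT provides both an
# image processor config and a fine-tuned classification head with id2label
image_processor = AutoImageProcessor.from_pretrained("eljandoubi/donut-base-encoder")
model = DonutSwinForImageClassification.from_pretrained("eljandoubi/donut-base-encoder")

inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])  # per _IMAGE_CLASS_EXPECTED_OUTPUT: "tabby, tabby cat"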