From e45eba3b1c0537e01a0f6b85099c4de47798e186 Mon Sep 17 00:00:00 2001
From: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Date: Tue, 12 Jan 2021 15:26:32 +0100
Subject: [PATCH] Improve LayoutLM (#9476)

* Add LayoutLMForSequenceClassification and integration tests

Improve docs

Add LayoutLM notebook to list of community notebooks

* Make style & quality

* Address comments by @sgugger, @patrickvonplaten and @LysandreJik

* Fix rebase with master

* Reformat in one line

* Improve code examples as requested by @patrickvonplaten

Co-authored-by: Lysandre
Co-authored-by: Lysandre Debut
---
 docs/source/model_doc/layoutlm.rst            |  69 +++-
 notebooks/README.md                           |   3 +-
 src/transformers/__init__.py                  |   2 +
 src/transformers/models/auto/modeling_auto.py |   8 +-
 src/transformers/models/layoutlm/__init__.py  |   2 +
 .../models/layoutlm/modeling_layoutlm.py      | 303 ++++++++++++++----
 src/transformers/utils/dummy_pt_objects.py    |   9 +
 tests/test_modeling_layoutlm.py               | 147 +++++++--
 8 files changed, 438 insertions(+), 105 deletions(-)

diff --git a/docs/source/model_doc/layoutlm.rst b/docs/source/model_doc/layoutlm.rst
index 1dc9ee971f4..413af4ca70d 100644
--- a/docs/source/model_doc/layoutlm.rst
+++ b/docs/source/model_doc/layoutlm.rst
@@ -13,32 +13,72 @@ LayoutLM
 -----------------------------------------------------------------------------------------------------------------------
 
+.. _Overview:
+
 Overview
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 The LayoutLM model was proposed in the paper `LayoutLM: Pre-training of Text and Layout for Document Image
 Understanding <https://arxiv.org/abs/1912.13318>`__ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, and
 Ming Zhou. It's a simple but effective pretraining method of text and layout for document image understanding and
-information extraction tasks, such as form understanding and receipt understanding.
+information extraction tasks, such as form understanding and receipt understanding. It obtains state-of-the-art results
+on several downstream tasks:
+
+- form understanding: the `FUNSD <https://guillaumejaume.github.io/FUNSD/>`__ dataset (a collection of 199 annotated
+  forms comprising more than 30,000 words).
+- receipt understanding: the `SROIE <https://rrc.cvc.uab.es/?ch=13>`__ dataset (a collection of 626 receipts for
+  training and 347 receipts for testing).
+- document image classification: the `RVL-CDIP <https://www.cs.cmu.edu/~aharley/rvl-cdip/>`__ dataset (a collection of
+  400,000 images belonging to one of 16 classes).
 
 The abstract from the paper is the following:
 
 *Pre-training techniques have been verified successfully in a variety of NLP tasks in recent years. Despite the
 widespread use of pretraining models for NLP applications, they almost exclusively focus on text-level manipulation,
 while neglecting layout and style information that is vital for document image understanding. In this paper, we propose
-the \textbf{LayoutLM} to jointly model interactions between text and layout information across scanned document images,
-which is beneficial for a great number of real-world document image understanding tasks such as information extraction
-from scanned documents. Furthermore, we also leverage image features to incorporate words' visual information into
-LayoutLM. To the best of our knowledge, this is the first time that text and layout are jointly learned in a single
-framework for document-level pretraining. 
It achieves new state-of-the-art results in several downstream tasks,
-including form understanding (from 70.72 to 79.27), receipt understanding (from 94.02 to 95.24) and document image
-classification (from 93.07 to 94.42).*
+the LayoutLM to jointly model interactions between text and layout information across scanned document images, which is
+beneficial for a great number of real-world document image understanding tasks such as information extraction from
+scanned documents. Furthermore, we also leverage image features to incorporate words' visual information into LayoutLM.
+To the best of our knowledge, this is the first time that text and layout are jointly learned in a single framework for
+document-level pretraining. It achieves new state-of-the-art results in several downstream tasks, including form
+understanding (from 70.72 to 79.27), receipt understanding (from 94.02 to 95.24) and document image classification
+(from 93.07 to 94.42).*
 
 Tips:
 
-- LayoutLM has an extra input called :obj:`bbox`, which is the bounding boxes of the input tokens.
-- The :obj:`bbox` requires the data that on 0-1000 scale, which means you should normalize the bounding box before
-  passing them into model.
+- In addition to :obj:`input_ids`, :meth:`~transformers.LayoutLMModel.forward` also expects the input :obj:`bbox`,
+  which are the bounding boxes (i.e. 2D-positions) of the input tokens. These can be obtained using an external OCR
+  engine such as Google's `Tesseract <https://github.com/tesseract-ocr/tesseract>`__ (there's a `Python wrapper
+  <https://pypi.org/project/pytesseract/>`__ available). Each bounding box should be in (x0, y0, x1, y1) format, where
+  (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1, y1) represents the
+  position of the lower right corner. Note that one first needs to normalize the bounding boxes to be on a 0-1000
+  scale. To normalize, you can use the following function:
+
+.. code-block::
+
+    def normalize_bbox(bbox, width, height):
+        return [
+            int(1000 * (bbox[0] / width)),
+            int(1000 * (bbox[1] / height)),
+            int(1000 * (bbox[2] / width)),
+            int(1000 * (bbox[3] / height)),
+        ]
+
+Here, :obj:`width` and :obj:`height` correspond to the width and height of the original document in which the token
+occurs. Those can be obtained using the Python Imaging Library (PIL), for example, as follows:
+
+.. code-block::
+
+    from PIL import Image
+
+    image = Image.open("name_of_your_document - can be a png file, pdf, etc.")
+
+    width, height = image.size
+
+- For a demo which shows how to fine-tune :class:`~transformers.LayoutLMForTokenClassification` on the `FUNSD dataset
+  <https://guillaumejaume.github.io/FUNSD/>`__ (a collection of annotated forms), see `this notebook
+  <https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LayoutLM/Fine_tuning_LayoutLMForTokenClassification_on_FUNSD.ipynb>`__.
+  It includes an inference part, which shows how to use Google's Tesseract on a new document.
 
 The original code can be found `here <https://github.com/microsoft/unilm/tree/master/layoutlm>`_.
 
@@ -78,6 +118,13 @@ LayoutLMForMaskedLM
     :members:
 
 
+LayoutLMForSequenceClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.LayoutLMForSequenceClassification
+    :members:
+
+
 LayoutLMForTokenClassification
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/notebooks/README.md b/notebooks/README.md
index 3cccd51e14c..74e90b1b667 100644
--- a/notebooks/README.md
+++ b/notebooks/README.md
@@ -76,4 +76,5 @@ Pull Request so it can be included under the Community notebooks. 
|[Fine-tuning TAPAS on Sequential Question Answering (SQA)](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb) | How to fine-tune *TapasForQuestionAnswering* with a *tapas-base* checkpoint on the Sequential Question Answering (SQA) dataset | [Niels Rogge](https://github.com/nielsrogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb)| |[Evaluating TAPAS on Table Fact Checking (TabFact)](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Evaluating_TAPAS_on_the_Tabfact_test_set.ipynb) | How to evaluate a fine-tuned *TapasForSequenceClassification* with a *tapas-base-finetuned-tabfact* checkpoint using a combination of the 🤗 datasets and 🤗 transformers libraries | [Niels Rogge](https://github.com/nielsrogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Evaluating_TAPAS_on_the_Tabfact_test_set.ipynb)| |[Fine-tuning mBART for translation](https://colab.research.google.com/github/vasudevgupta7/huggingface-tutorials/blob/main/translation_training.ipynb) | How to fine-tune mBART using Seq2SeqTrainer for Hindi to English translation | [Vasudev Gupta](https://github.com/vasudevgupta7) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vasudevgupta7/huggingface-tutorials/blob/main/translation_training.ipynb)| -[Fine-Tune DistilGPT2 and Generate Text](https://colab.research.google.com/github/tripathiaakash/DistilGPT2-Tutorial/blob/main/distilgpt2_fine_tuning.ipynb) | How to fine-tune DistilGPT2 and generate text | [Aakash Tripathi](https://github.com/tripathiaakash) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tripathiaakash/DistilGPT2-Tutorial/blob/main/distilgpt2_fine_tuning.ipynb)| +|[Fine-tuning LayoutLM on FUNSD (a form understanding dataset)](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LayoutLM/Fine_tuning_LayoutLMForTokenClassification_on_FUNSD.ipynb) | How to fine-tune *LayoutLMForTokenClassification* on the FUNSD dataset | [Niels Rogge](https://github.com/nielsrogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/LayoutLM/Fine_tuning_LayoutLMForTokenClassification_on_FUNSD.ipynb)| +|[Fine-Tune DistilGPT2 and Generate Text](https://colab.research.google.com/github/tripathiaakash/DistilGPT2-Tutorial/blob/main/distilgpt2_fine_tuning.ipynb) | How to fine-tune DistilGPT2 and generate text | [Aakash Tripathi](https://github.com/tripathiaakash) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tripathiaakash/DistilGPT2-Tutorial/blob/main/distilgpt2_fine_tuning.ipynb)| diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 5ffaa001b50..df435e7431e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -561,6 +561,7 @@ if is_torch_available(): [ "LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST", "LayoutLMForMaskedLM", + "LayoutLMForSequenceClassification", "LayoutLMForTokenClassification", "LayoutLMModel", ] @@ -1597,6 +1598,7 
@@ if TYPE_CHECKING:
     from .models.layoutlm import (
         LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
         LayoutLMForMaskedLM,
+        LayoutLMForSequenceClassification,
         LayoutLMForTokenClassification,
         LayoutLMModel,
     )
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 8a2856712eb..8c34ebe5d81 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -101,7 +101,12 @@ from ..funnel.modeling_funnel import (
     FunnelModel,
 )
 from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model
-from ..layoutlm.modeling_layoutlm import LayoutLMForMaskedLM, LayoutLMForTokenClassification, LayoutLMModel
+from ..layoutlm.modeling_layoutlm import (
+    LayoutLMForMaskedLM,
+    LayoutLMForSequenceClassification,
+    LayoutLMForTokenClassification,
+    LayoutLMModel,
+)
 from ..led.modeling_led import (
     LEDForConditionalGeneration,
     LEDForQuestionAnswering,
@@ -470,6 +475,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
         (TransfoXLConfig, TransfoXLForSequenceClassification),
         (MPNetConfig, MPNetForSequenceClassification),
         (TapasConfig, TapasForSequenceClassification),
+        (LayoutLMConfig, LayoutLMForSequenceClassification),
     ]
 )
diff --git a/src/transformers/models/layoutlm/__init__.py b/src/transformers/models/layoutlm/__init__.py
index 2aab097d97e..30825bf0125 100644
--- a/src/transformers/models/layoutlm/__init__.py
+++ b/src/transformers/models/layoutlm/__init__.py
@@ -33,6 +33,7 @@ if is_torch_available():
     _import_structure["modeling_layoutlm"] = [
         "LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST",
         "LayoutLMForMaskedLM",
+        "LayoutLMForSequenceClassification",
         "LayoutLMForTokenClassification",
         "LayoutLMModel",
     ]
@@ -49,6 +50,7 @@ if TYPE_CHECKING:
     from .modeling_layoutlm import (
         LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
         LayoutLMForMaskedLM,
+        LayoutLMForSequenceClassification,
         LayoutLMForTokenClassification,
         LayoutLMModel,
     )
diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py
index 3e3318121a9..d611336fbfc 100644
--- a/src/transformers/models/layoutlm/modeling_layoutlm.py
+++ b/src/transformers/models/layoutlm/modeling_layoutlm.py
@@ -19,14 +19,15 @@ import math
 
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
     MaskedLMOutput,
+    SequenceClassifierOutput,
     TokenClassifierOutput,
 )
 from ...modeling_utils import (
@@ -596,6 +597,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
     """
 
     config_class = LayoutLMConfig
+    pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST
     base_model_prefix = "layoutlm"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
@@ -614,7 +616,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
 
 LAYOUTLM_START_DOCSTRING = r"""
     The LayoutLM model was proposed in `LayoutLM: Pre-training of Text and Layout for Document Image Understanding
-    <https://arxiv.org/abs/1912.13318>`__ by....
+    <https://arxiv.org/abs/1912.13318>`__ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei and Ming Zhou.
 
     This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. 
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
@@ -638,8 +640,10 @@ LAYOUTLM_INPUTS_DOCSTRING = r"""
             `What are input IDs? <../glossary.html#input-ids>`__
         bbox (:obj:`torch.LongTensor` of shape :obj:`({0}, 4)`, `optional`):
-            Bounding Boxes of each input sequence tokens. Selected in the range ``[0,
-            config.max_2d_position_embeddings-1]``.
+            Bounding boxes of each input sequence token. Selected in the range ``[0,
+            config.max_2d_position_embeddings-1]``. Each bounding box should be a normalized version in (x0, y0, x1,
+            y1) format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and
+            (x1, y1) represents the position of the lower right corner. See :ref:`Overview` for normalization.
         attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
             Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: ``1`` for
             tokens that are NOT MASKED, ``0`` for MASKED tokens.
@@ -679,11 +683,6 @@ LAYOUTLM_INPUTS_DOCSTRING = r"""
     LAYOUTLM_START_DOCSTRING,
 )
 class LayoutLMModel(LayoutLMPreTrainedModel):
-
-    config_class = LayoutLMConfig
-    pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST
-    base_model_prefix = "layoutlm"
-
     def __init__(self, config):
         super(LayoutLMModel, self).__init__(config)
         self.config = config
@@ -709,12 +708,7 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
         self.encoder.layer[layer].attention.prune_heads(heads)
 
     @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @add_code_sample_docstrings(
-        tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="layoutlm-base-uncased",
-        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
-        config_class=_CONFIG_FOR_DOC,
-    )
+    @replace_return_docstrings(output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids=None,
@@ -730,31 +724,36 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
     ):
-        """
-        input_ids (torch.LongTensor of shape (batch_size, sequence_length)):
-            Indices of input sequence tokens in the vocabulary.
-        attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional):
-            Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]: 1 for tokens
-            that are NOT MASKED, 0 for MASKED tokens.
-        token_type_ids (torch.LongTensor of shape (batch_size, sequence_length), optional):
-            Segment token indices to indicate first and second portions of the inputs. Indices are selected in [0, 1]:
-            0 corresponds to a sentence A token, 1 corresponds to a sentence B token
-        position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional):
-            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range [0,
-            config.max_position_embeddings - 1].
-        head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional):
-            Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]: 1 indicates
-            the head is not masked, 0 indicates the head is masked.
-        inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional):
-            Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. 
This - is useful if you want more control over how to convert input_ids indices into associated vectors than the - model’s internal embedding lookup matrix. - output_attentions (bool, optional): - If set to True, the attentions tensors of all attention layers are returned. - output_hidden_states (bool, optional): - If set to True, the hidden states of all layers are returned. - return_dict (bool, optional): - If set to True, the model will return a ModelOutput instead of a plain tuple. + r""" + Returns: + + Examples:: + + >>> from transformers import LayoutLMTokenizer, LayoutLMModel + >>> import torch + + >>> tokenizer = LayoutLMTokenizer.from_pretrained('microsoft/layoutlm-base-uncased') + >>> model = LayoutLMModel.from_pretrained('microsoft/layoutlm-base-uncased') + + >>> words = ["Hello", "world"] + >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782] + + >>> token_boxes = [] + >>> for word, box in zip(words, normalized_word_boxes): + ... word_tokens = tokenizer.tokenize(word) + ... token_boxes.extend([box] * len(word_tokens)) + >>> # add bounding boxes of cls + sep tokens + >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]] + + >>> encoding = tokenizer(' '.join(words), return_tensors="pt") + >>> input_ids = encoding["input_ids"] + >>> attention_mask = encoding["attention_mask"] + >>> token_type_ids = encoding["token_type_ids"] + >>> bbox = torch.tensor([token_boxes]) + + >>> outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids) + + >>> last_hidden_states = outputs.last_hidden_state """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -828,10 +827,6 @@ class LayoutLMModel(LayoutLMPreTrainedModel): @add_start_docstrings("""LayoutLM Model with a `language modeling` head on top. """, LAYOUTLM_START_DOCSTRING) class LayoutLMForMaskedLM(LayoutLMPreTrainedModel): - config_class = LayoutLMConfig - pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST - base_model_prefix = "layoutlm" - def __init__(self, config): super().__init__(config) @@ -850,12 +845,7 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel): self.cls.predictions.decoder = new_embeddings @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - tokenizer_class=_TOKENIZER_FOR_DOC, - checkpoint="layoutlm-base-uncased", - output_type=MaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - ) + @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids=None, @@ -872,7 +862,45 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel): output_hidden_states=None, return_dict=None, ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. 
Indices should be in ``[-100, 0, ...,
+            config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
+            (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+
+        Returns:
+
+        Examples::
+
+            >>> from transformers import LayoutLMTokenizer, LayoutLMForMaskedLM
+            >>> import torch
+
+            >>> tokenizer = LayoutLMTokenizer.from_pretrained('microsoft/layoutlm-base-uncased')
+            >>> model = LayoutLMForMaskedLM.from_pretrained('microsoft/layoutlm-base-uncased')
+
+            >>> words = ["Hello", "[MASK]"]
+            >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
+
+            >>> token_boxes = []
+            >>> for word, box in zip(words, normalized_word_boxes):
+            ...     word_tokens = tokenizer.tokenize(word)
+            ...     token_boxes.extend([box] * len(word_tokens))
+            >>> # add bounding boxes of cls + sep tokens
+            >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
+
+            >>> encoding = tokenizer(' '.join(words), return_tensors="pt")
+            >>> input_ids = encoding["input_ids"]
+            >>> attention_mask = encoding["attention_mask"]
+            >>> token_type_ids = encoding["token_type_ids"]
+            >>> bbox = torch.tensor([token_boxes])
+
+            >>> labels = tokenizer("Hello world", return_tensors="pt")["input_ids"]
+
+            >>> outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids,
+            ...                 labels=labels)
+
+            >>> loss = outputs.loss
+        """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         outputs = self.layoutlm(
@@ -915,16 +943,12 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
 
 @add_start_docstrings(
     """
-    LayoutLM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
-    Named-Entity-Recognition (NER) tasks.
+    LayoutLM Model with a sequence classification head on top (a linear layer on top of the pooled output) e.g. for
+    document image classification tasks such as the `RVL-CDIP <https://www.cs.cmu.edu/~aharley/rvl-cdip/>`__ dataset.
     """,
     LAYOUTLM_START_DOCSTRING,
 )
-class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
-    config_class = LayoutLMConfig
-    pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST
-    base_model_prefix = "layoutlm"
-
+class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -938,12 +962,7 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
         return self.layoutlm.embeddings.word_embeddings
 
     @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @add_code_sample_docstrings(
-        tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="layoutlm-base-uncased",
-        output_type=TokenClassifierOutput,
-        config_class=_CONFIG_FOR_DOC,
-    )
+    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids=None,
@@ -958,6 +977,162 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
     ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
+            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square
+            loss). If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+
+        Returns:
+
+        Examples::
+
+            >>> from transformers import LayoutLMTokenizer, LayoutLMForSequenceClassification
+            >>> import torch
+
+            >>> tokenizer = LayoutLMTokenizer.from_pretrained('microsoft/layoutlm-base-uncased')
+            >>> model = LayoutLMForSequenceClassification.from_pretrained('microsoft/layoutlm-base-uncased')
+
+            >>> words = ["Hello", "world"]
+            >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
+
+            >>> token_boxes = []
+            >>> for word, box in zip(words, normalized_word_boxes):
+            ...     word_tokens = tokenizer.tokenize(word)
+            ...     token_boxes.extend([box] * len(word_tokens))
+            >>> # add bounding boxes of cls + sep tokens
+            >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
+
+            >>> encoding = tokenizer(' '.join(words), return_tensors="pt")
+            >>> input_ids = encoding["input_ids"]
+            >>> attention_mask = encoding["attention_mask"]
+            >>> token_type_ids = encoding["token_type_ids"]
+            >>> bbox = torch.tensor([token_boxes])
+            >>> sequence_label = torch.tensor([1])
+
+            >>> outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids,
+            ...                 labels=sequence_label)
+
+            >>> loss = outputs.loss
+            >>> logits = outputs.logits
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.layoutlm(
+            input_ids=input_ids,
+            bbox=bbox,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            if self.num_labels == 1:
+                #  We are doing regression
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1), labels.view(-1))
+            else:
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    LayoutLM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+    sequence labeling (information extraction) tasks such as the `FUNSD <https://guillaumejaume.github.io/FUNSD/>`__
+    dataset and the `SROIE <https://rrc.cvc.uab.es/?ch=13>`__ dataset. 
+ """, + LAYOUTLM_START_DOCSTRING, +) +class LayoutLMForTokenClassification(LayoutLMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.layoutlm = LayoutLMModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + def get_input_embeddings(self): + return self.layoutlm.embeddings.word_embeddings + + @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + bbox=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - + 1]``. + + Returns: + + Examples:: + + >>> from transformers import LayoutLMTokenizer, LayoutLMForTokenClassification + >>> import torch + + >>> tokenizer = LayoutLMTokenizer.from_pretrained('microsoft/layoutlm-base-uncased') + >>> model = LayoutLMForTokenClassification.from_pretrained('microsoft/layoutlm-base-uncased') + + >>> words = ["Hello", "world"] + >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782] + + >>> token_boxes = [] + >>> for word, box in zip(words, normalized_word_boxes): + ... word_tokens = tokenizer.tokenize(word) + ... token_boxes.extend([box] * len(word_tokens)) + >>> # add bounding boxes of cls + sep tokens + >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]] + + >>> encoding = tokenizer(' '.join(words), return_tensors="pt") + >>> input_ids = encoding["input_ids"] + >>> attention_mask = encoding["attention_mask"] + >>> token_type_ids = encoding["token_type_ids"] + >>> bbox = torch.tensor([token_boxes]) + >>> token_labels = torch.tensor([1,1,0,0]).unsqueeze(0) # batch size of 1 + + >>> outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids, + ... 
labels=token_labels) + + >>> loss = outputs.loss + >>> logits = outputs.logits + """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.layoutlm( diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 072ab5c8808..bbc444d3d2a 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1182,6 +1182,15 @@ class LayoutLMForMaskedLM: requires_pytorch(self) +class LayoutLMForSequenceClassification: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + class LayoutLMForTokenClassification: def __init__(self, *args, **kwargs): requires_pytorch(self) diff --git a/tests/test_modeling_layoutlm.py b/tests/test_modeling_layoutlm.py index ff0eb7ef02b..2d64b55d8a1 100644 --- a/tests/test_modeling_layoutlm.py +++ b/tests/test_modeling_layoutlm.py @@ -17,19 +17,26 @@ import unittest from transformers import is_torch_available -from transformers.file_utils import cached_property -from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device +from transformers.testing_utils import require_torch, slow, torch_device from .test_configuration_common import ConfigTester from .test_modeling_common import ModelTesterMixin, ids_tensor if is_torch_available(): - from transformers import LayoutLMConfig, LayoutLMForMaskedLM, LayoutLMForTokenClassification, LayoutLMModel + import torch + + from transformers import ( + LayoutLMConfig, + LayoutLMForMaskedLM, + LayoutLMForSequenceClassification, + LayoutLMForTokenClassification, + LayoutLMModel, + ) class LayoutLMModelTester: - """You can also import this e.g from .test_modeling_bart import BartModelTester """ + """You can also import this e.g from .test_modeling_layoutlm import LayoutLMModelTester """ def __init__( self, @@ -150,6 +157,18 @@ class LayoutLMModelTester: result = model(input_ids, bbox, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + def create_and_check_for_sequence_classification( + self, config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = LayoutLMForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model( + input_ids, bbox, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels + ) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + def create_and_check_for_token_classification( self, config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): @@ -185,7 +204,14 @@ class LayoutLMModelTester: class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = ( - (LayoutLMModel, LayoutLMForMaskedLM, LayoutLMForTokenClassification) if is_torch_available() else () + ( + LayoutLMModel, + LayoutLMForMaskedLM, + LayoutLMForSequenceClassification, + LayoutLMForTokenClassification, + ) + if is_torch_available() + else None ) def setUp(self): @@ -209,36 +235,101 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + def test_for_sequence_classification(self): + 
config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
+
     def test_for_token_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
 
-    @cached_property
-    def big_model(self):
-        """Cached property means this code will only be executed once."""
-        checkpoint_path = "microsoft/layoutlm-large-uncased"
-        model = LayoutLMForMaskedLM.from_pretrained(checkpoint_path).to(
-            torch_device
-        )  # test whether AutoModel can determine your model_class from checkpoint name
-        if torch_device == "cuda":
-            model.half()
 
-    # optional: do more testing! This will save you time later!
+def prepare_layoutlm_batch_inputs():
+    # Here we prepare a batch of 2 sequences to test a LayoutLM forward pass on:
+    # fmt: off
+    input_ids = torch.tensor([[101,1019,1014,1016,1037,12849,4747,1004,14246,2278,5439,4524,5002,2930,2193,2930,4341,3208,1005,1055,2171,2848,11300,3531,102],[101,4070,4034,7020,1024,3058,1015,1013,2861,1013,6070,19274,2772,6205,27814,16147,16147,4343,2047,10283,10969,14389,1013,2861,102]],device=torch_device)  # noqa: E231
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],],device=torch_device)  # noqa: E231
+    bbox = torch.tensor([[[0,0,0,0],[423,237,440,251],[427,272,441,287],[419,115,437,129],[961,885,992,912],[256,38,330,58],[256,38,330,58],[336,42,353,57],[360,39,401,56],[360,39,401,56],[411,39,471,59],[479,41,528,59],[533,39,630,60],[67,113,134,131],[141,115,209,132],[68,149,133,166],[141,149,187,164],[195,148,287,165],[195,148,287,165],[195,148,287,165],[295,148,349,165],[441,149,492,166],[497,149,546,164],[64,201,125,218],[1000,1000,1000,1000]],[[0,0,0,0],[662,150,754,166],[665,199,742,211],[519,213,554,228],[519,213,554,228],[134,433,187,454],[130,467,204,480],[130,467,204,480],[130,467,204,480],[130,467,204,480],[130,467,204,480],[314,469,376,482],[504,684,582,706],[941,825,973,900],[941,825,973,900],[941,825,973,900],[941,825,973,900],[610,749,652,765],[130,659,168,672],[176,657,237,672],[238,657,312,672],[443,653,628,672],[443,653,628,672],[716,301,825,317],[1000,1000,1000,1000]]],device=torch_device)  # noqa: E231
+    token_type_ids = torch.tensor([[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]],device=torch_device)  # noqa: E231
+    # these are sequence labels (i.e. 
at the token level) + labels = torch.tensor([[-100,10,10,10,9,1,-100,7,7,-100,7,7,4,2,5,2,8,8,-100,-100,5,0,3,2,-100],[-100,12,12,12,-100,12,10,-100,-100,-100,-100,10,12,9,-100,-100,-100,10,10,10,9,12,-100,10,-100]],device=torch_device) # noqa: E231 + # fmt: on + + return input_ids, attention_mask, bbox, token_type_ids, labels + + +@require_torch +class LayoutLMModelIntegrationTest(unittest.TestCase): @slow - def test_that_LayoutLM_can_be_used_in_a_pipeline(self): - """We can use self.big_model here without calling __init__ again.""" - pass + def test_forward_pass_no_head(self): + model = LayoutLMModel.from_pretrained("microsoft/layoutlm-base-uncased").to(torch_device) - def test_LayoutLM_loss_doesnt_change_if_you_add_padding(self): - pass + input_ids, attention_mask, bbox, token_type_ids, labels = prepare_layoutlm_batch_inputs() - def test_LayoutLM_bad_args(self): - pass + # forward pass + outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids) - def test_LayoutLM_backward_pass_reduces_loss(self): - """Test loss/gradients same as reference implementation, for example.""" - pass + # test the sequence output on [0, :3, :3] + expected_slice = torch.tensor( + [[0.1785, -0.1947, -0.0425], [-0.3254, -0.2807, 0.2553], [-0.5391, -0.3322, 0.3364]], + device=torch_device, + ) - @require_torch_gpu - def test_large_inputs_in_fp16_dont_cause_overflow(self): - pass + self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-3)) + + # test the pooled output on [1, :3] + expected_slice = torch.tensor([-0.6580, -0.0214, 0.8552], device=torch_device) + + self.assertTrue(torch.allclose(outputs.pooler_output[1, :3], expected_slice, atol=1e-3)) + + @slow + def test_forward_pass_sequence_classification(self): + # initialize model with randomly initialized sequence classification head + model = LayoutLMForSequenceClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=2).to( + torch_device + ) + + input_ids, attention_mask, bbox, token_type_ids, _ = prepare_layoutlm_batch_inputs() + + # forward pass + outputs = model( + input_ids=input_ids, + bbox=bbox, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + labels=torch.tensor([1, 1], device=torch_device), + ) + + # test whether we get a loss as a scalar + loss = outputs.loss + expected_shape = torch.Size([]) + self.assertEqual(loss.shape, expected_shape) + + # test the shape of the logits + logits = outputs.logits + expected_shape = torch.Size((2, 2)) + self.assertEqual(logits.shape, expected_shape) + + @slow + def test_forward_pass_token_classification(self): + # initialize model with randomly initialized token classification head + model = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=13).to( + torch_device + ) + + input_ids, attention_mask, bbox, token_type_ids, labels = prepare_layoutlm_batch_inputs() + + # forward pass + outputs = model( + input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels + ) + + # test the loss calculation to be around 2.65 + expected_loss = torch.tensor(2.65, device=torch_device) + + self.assertTrue(torch.allclose(outputs.loss, expected_loss, atol=0.1)) + + # test the shape of the logits + logits = outputs.logits + expected_shape = torch.Size((2, 25, 13)) + self.assertEqual(logits.shape, expected_shape)
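
---

Read together, the documentation and model additions in this patch imply the end-to-end flow below. This is a minimal sketch, not part of the patch: it assumes the `pytesseract` wrapper mentioned in the docs is installed, uses the placeholder file name `document.png` and an arbitrary `num_labels=2`, and skips truncation handling, so it only covers documents that fit in LayoutLM's 512-token window.

```python
import pytesseract
import torch
from PIL import Image

from transformers import LayoutLMForSequenceClassification, LayoutLMTokenizer


def normalize_bbox(bbox, width, height):
    # scale an (x0, y0, x1, y1) pixel box to the 0-1000 range LayoutLM expects
    return [
        int(1000 * (bbox[0] / width)),
        int(1000 * (bbox[1] / height)),
        int(1000 * (bbox[2] / width)),
        int(1000 * (bbox[3] / height)),
    ]


image = Image.open("document.png").convert("RGB")  # placeholder file name
width, height = image.size

# pytesseract reports one left/top/width/height quadruple per recognized word
ocr = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
words, word_boxes = [], []
for text, left, top, w, h in zip(ocr["text"], ocr["left"], ocr["top"], ocr["width"], ocr["height"]):
    if text.strip():
        words.append(text)
        word_boxes.append(normalize_bbox([left, top, left + w, top + h], width, height))

tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
# the classification head is randomly initialized until the model is fine-tuned
model = LayoutLMForSequenceClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=2)

# repeat each word-level box for every wordpiece of that word, then add [CLS]/[SEP] boxes,
# mirroring the docstring examples added by this patch
token_boxes = []
for word, box in zip(words, word_boxes):
    token_boxes.extend([box] * len(tokenizer.tokenize(word)))
token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]

encoding = tokenizer(" ".join(words), return_tensors="pt")
outputs = model(
    input_ids=encoding["input_ids"],
    bbox=torch.tensor([token_boxes]),
    attention_mask=encoding["attention_mask"],
    token_type_ids=encoding["token_type_ids"],
)
predicted_class = outputs.logits.argmax(-1).item()
```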
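The new `(LayoutLMConfig, LayoutLMForSequenceClassification)` entry in `MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING` also makes LayoutLM checkpoints resolvable through the auto classes. A quick sanity check, assuming a transformers build that includes this patch:

```python
from transformers import AutoModelForSequenceClassification

# the auto class dispatches the LayoutLM checkpoint to the new head via the mapping above
model = AutoModelForSequenceClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=2)
print(type(model).__name__)  # LayoutLMForSequenceClassification
```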