Improve LayoutLM (#9476)

* Add LayoutLMForSequenceClassification and integration tests

Improve docs

Add LayoutLM notebook to list of community notebooks

* Make style & quality

* Address comments by @sgugger, @patrickvonplaten and @LysandreJik

* Fix rebase with master

* Reformat in one line

* Improve code examples as requested by @patrickvonplaten

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
NielsRogge 2021-01-12 15:26:32 +01:00 committed by GitHub
parent ccd1923f46
commit e45eba3b1c
8 changed files with 438 additions and 105 deletions


@ -13,32 +13,72 @@
LayoutLM
-----------------------------------------------------------------------------------------------------------------------
.. _Overview:
Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The LayoutLM model was proposed in the paper `LayoutLM: Pre-training of Text and Layout for Document Image
Understanding <https://arxiv.org/abs/1912.13318>`__ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, and
Ming Zhou. It's a simple but effective pretraining method of text and layout for document image understanding and
information extraction tasks, such as form understanding and receipt understanding. It obtains state-of-the-art results
on several downstream tasks:
- form understanding: the `FUNSD <https://guillaumejaume.github.io/FUNSD/>`__ dataset (a collection of 199 annotated
forms comprising more than 30,000 words).
- receipt understanding: the `SROIE <https://rrc.cvc.uab.es/?ch=13>`__ dataset (a collection of 626 receipts for
training and 347 receipts for testing).
- document image classification: the `RVL-CDIP <https://www.cs.cmu.edu/~aharley/rvl-cdip/>`__ dataset (a collection of
400,000 images belonging to one of 16 classes).
The abstract from the paper is the following:
*Pre-training techniques have been verified successfully in a variety of NLP tasks in recent years. Despite the
widespread use of pretraining models for NLP applications, they almost exclusively focus on text-level manipulation,
while neglecting layout and style information that is vital for document image understanding. In this paper, we propose
the LayoutLM to jointly model interactions between text and layout information across scanned document images, which is
beneficial for a great number of real-world document image understanding tasks such as information extraction from
scanned documents. Furthermore, we also leverage image features to incorporate words' visual information into LayoutLM.
To the best of our knowledge, this is the first time that text and layout are jointly learned in a single framework for
document-level pretraining. It achieves new state-of-the-art results in several downstream tasks, including form
understanding (from 70.72 to 79.27), receipt understanding (from 94.02 to 95.24) and document image classification
(from 93.07 to 94.42).*
Tips:
- In addition to :obj:`input_ids`, :meth:`~transformers.LayoutLMModel.forward` also expects the input :obj:`bbox`, which
  are the bounding boxes (i.e. 2D-positions) of the input tokens. These can be obtained using an external OCR engine
  such as Google's `Tesseract <https://github.com/tesseract-ocr/tesseract>`__ (there's a `Python wrapper
  <https://pypi.org/project/pytesseract/>`__ available; see the sketch after these tips). Each bounding box should be
  in (x0, y0, x1, y1) format, where (x0, y0) corresponds to the position of the upper left corner of the bounding box,
  and (x1, y1) represents the position of the lower right corner. Note that one first needs to normalize the bounding
  boxes to be on a 0-1000 scale. To normalize, you can use the following function:

  .. code-block::

      def normalize_bbox(bbox, width, height):
          return [
              int(1000 * (bbox[0] / width)),
              int(1000 * (bbox[1] / height)),
              int(1000 * (bbox[2] / width)),
              int(1000 * (bbox[3] / height)),
          ]

  Here, :obj:`width` and :obj:`height` correspond to the width and height of the original document in which the token
  occurs. Those can be obtained using the Python Imaging Library (PIL), for example, as follows:

  .. code-block::

      from PIL import Image

      image = Image.open("name_of_your_document - can be a png file, pdf, etc.")
      width, height = image.size

- For a demo which shows how to fine-tune :class:`LayoutLMForTokenClassification` on the `FUNSD dataset
<https://guillaumejaume.github.io/FUNSD/>`__ (a collection of annotated forms), see `this notebook
<https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LayoutLM/Fine_tuning_LayoutLMForTokenClassification_on_FUNSD.ipynb>`__.
It includes an inference part, which shows how to use Google's Tesseract on a new document.
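
A minimal, illustrative sketch of how word-level bounding boxes could be obtained with the ``pytesseract`` wrapper
mentioned above and then normalized with ``normalize_bbox``; the file name is a placeholder, and any OCR engine that
returns word boxes would work equally well:

.. code-block::

    import pytesseract
    from PIL import Image

    image = Image.open("scanned_page.png")  # placeholder file name
    width, height = image.size

    # image_to_data returns, for every detected word, its text and its (left, top, width, height) box
    ocr = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)

    words, boxes = [], []
    for text, left, top, w, h in zip(ocr["text"], ocr["left"], ocr["top"], ocr["width"], ocr["height"]):
        if text.strip():  # skip entries with no recognized text
            words.append(text)
            # convert to (x0, y0, x1, y1) and normalize to the 0-1000 scale expected by LayoutLM
            boxes.append(normalize_bbox([left, top, left + w, top + h], width, height))
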
The original code can be found `here <https://github.com/microsoft/unilm/tree/master/layoutlm>`_.
@ -78,6 +118,13 @@ LayoutLMForMaskedLM
:members:
LayoutLMForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.LayoutLMForSequenceClassification
:members:
LayoutLMForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


@ -76,4 +76,5 @@ Pull Request so it can be included under the Community notebooks.
|[Fine-tuning TAPAS on Sequential Question Answering (SQA)](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb) | How to fine-tune *TapasForQuestionAnswering* with a *tapas-base* checkpoint on the Sequential Question Answering (SQA) dataset | [Niels Rogge](https://github.com/nielsrogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb)|
|[Evaluating TAPAS on Table Fact Checking (TabFact)](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Evaluating_TAPAS_on_the_Tabfact_test_set.ipynb) | How to evaluate a fine-tuned *TapasForSequenceClassification* with a *tapas-base-finetuned-tabfact* checkpoint using a combination of the 🤗 datasets and 🤗 transformers libraries | [Niels Rogge](https://github.com/nielsrogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Evaluating_TAPAS_on_the_Tabfact_test_set.ipynb)|
|[Fine-tuning mBART for translation](https://colab.research.google.com/github/vasudevgupta7/huggingface-tutorials/blob/main/translation_training.ipynb) | How to fine-tune mBART using Seq2SeqTrainer for Hindi to English translation | [Vasudev Gupta](https://github.com/vasudevgupta7) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vasudevgupta7/huggingface-tutorials/blob/main/translation_training.ipynb)|
|[Fine-tuning LayoutLM on FUNSD (a form understanding dataset)](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LayoutLM/Fine_tuning_LayoutLMForTokenClassification_on_FUNSD.ipynb) | How to fine-tune *LayoutLMForTokenClassification* on the FUNSD dataset | [Niels Rogge](https://github.com/nielsrogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/LayoutLM/Fine_tuning_LayoutLMForTokenClassification_on_FUNSD.ipynb)|
|[Fine-Tune DistilGPT2 and Generate Text](https://colab.research.google.com/github/tripathiaakash/DistilGPT2-Tutorial/blob/main/distilgpt2_fine_tuning.ipynb) | How to fine-tune DistilGPT2 and generate text | [Aakash Tripathi](https://github.com/tripathiaakash) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tripathiaakash/DistilGPT2-Tutorial/blob/main/distilgpt2_fine_tuning.ipynb)|


@ -561,6 +561,7 @@ if is_torch_available():
[
"LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"LayoutLMForMaskedLM",
"LayoutLMForSequenceClassification",
"LayoutLMForTokenClassification",
"LayoutLMModel",
]
@ -1597,6 +1598,7 @@ if TYPE_CHECKING:
from .models.layoutlm import (
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
LayoutLMForMaskedLM,
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
LayoutLMModel,
)


@ -101,7 +101,12 @@ from ..funnel.modeling_funnel import (
FunnelModel,
)
from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model
from ..layoutlm.modeling_layoutlm import (
LayoutLMForMaskedLM,
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
LayoutLMModel,
)
from ..led.modeling_led import (
LEDForConditionalGeneration,
LEDForQuestionAnswering,
@ -470,6 +475,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
(TransfoXLConfig, TransfoXLForSequenceClassification),
(MPNetConfig, MPNetForSequenceClassification),
(TapasConfig, TapasForSequenceClassification),
(LayoutLMConfig, LayoutLMForSequenceClassification),
]
)
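
With ``LayoutLMConfig`` now present in ``MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING``, the auto classes can resolve the
new head directly from a LayoutLM checkpoint. A minimal, illustrative snippet (the ``num_labels`` value is only an
example and is not part of this diff):

    from transformers import AutoModelForSequenceClassification

    # resolves to LayoutLMForSequenceClassification via the mapping added above;
    # the classification head is randomly initialized and still needs to be fine-tuned
    model = AutoModelForSequenceClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=2)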


@ -33,6 +33,7 @@ if is_torch_available():
_import_structure["modeling_layoutlm"] = [
"LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"LayoutLMForMaskedLM",
"LayoutLMForSequenceClassification",
"LayoutLMForTokenClassification",
"LayoutLMModel",
]
@ -49,6 +50,7 @@ if TYPE_CHECKING:
from .modeling_layoutlm import (
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
LayoutLMForMaskedLM,
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
LayoutLMModel,
)


@ -19,14 +19,15 @@ import math
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
MaskedLMOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
@ -596,6 +597,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
"""
config_class = LayoutLMConfig
pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST
base_model_prefix = "layoutlm"
_keys_to_ignore_on_load_missing = [r"position_ids"]
@ -614,7 +616,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
LAYOUTLM_START_DOCSTRING = r"""
The LayoutLM model was proposed in `LayoutLM: Pre-training of Text and Layout for Document Image Understanding
<https://arxiv.org/abs/1912.13318>`__ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei and Ming Zhou.
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
@ -638,8 +640,10 @@ LAYOUTLM_INPUTS_DOCSTRING = r"""
`What are input IDs? <../glossary.html#input-ids>`__
bbox (:obj:`torch.LongTensor` of shape :obj:`({0}, 4)`, `optional`):
            Bounding boxes of each input sequence token. Selected in the range ``[0,
            config.max_2d_position_embeddings-1]``. Each bounding box should be a normalized version in (x0, y0, x1,
            y1) format, where (x0, y0) corresponds to the position of the upper left corner of the bounding box, and
(x1, y1) represents the position of the lower right corner. See :ref:`Overview` for normalization.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: ``1`` for
tokens that are NOT MASKED, ``0`` for MASKED tokens.
@ -679,11 +683,6 @@ LAYOUTLM_INPUTS_DOCSTRING = r"""
LAYOUTLM_START_DOCSTRING,
)
class LayoutLMModel(LayoutLMPreTrainedModel):
def __init__(self, config):
super(LayoutLMModel, self).__init__(config)
self.config = config
@ -709,12 +708,7 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -730,31 +724,36 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
output_hidden_states=None,
return_dict=None,
):
"""
input_ids (torch.LongTensor of shape (batch_size, sequence_length)):
Indices of input sequence tokens in the vocabulary.
attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional):
Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]: 1 for tokens
that are NOT MASKED, 0 for MASKED tokens.
token_type_ids (torch.LongTensor of shape (batch_size, sequence_length), optional):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in [0, 1]:
0 corresponds to a sentence A token, 1 corresponds to a sentence B token
position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range [0,
config.max_position_embeddings - 1].
head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional):
Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]: 1 indicates
the head is not masked, 0 indicates the head is masked.
inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional):
Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert input_ids indices into associated vectors than the
models internal embedding lookup matrix.
output_attentions (bool, optional):
If set to True, the attentions tensors of all attention layers are returned.
output_hidden_states (bool, optional):
If set to True, the hidden states of all layers are returned.
return_dict (bool, optional):
If set to True, the model will return a ModelOutput instead of a plain tuple.
r"""
Returns:
Examples::
>>> from transformers import LayoutLMTokenizer, LayoutLMModel
>>> import torch
>>> tokenizer = LayoutLMTokenizer.from_pretrained('microsoft/layoutlm-base-uncased')
>>> model = LayoutLMModel.from_pretrained('microsoft/layoutlm-base-uncased')
>>> words = ["Hello", "world"]
>>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
>>> token_boxes = []
>>> for word, box in zip(words, normalized_word_boxes):
... word_tokens = tokenizer.tokenize(word)
... token_boxes.extend([box] * len(word_tokens))
>>> # add bounding boxes of cls + sep tokens
>>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
>>> encoding = tokenizer(' '.join(words), return_tensors="pt")
>>> input_ids = encoding["input_ids"]
>>> attention_mask = encoding["attention_mask"]
>>> token_type_ids = encoding["token_type_ids"]
>>> bbox = torch.tensor([token_boxes])
>>> outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids)
>>> last_hidden_states = outputs.last_hidden_state
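        >>> # illustrative addition (not in the original example): the pooled [CLS] representation is also available
        >>> pooled_output = outputs.pooler_output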
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -828,10 +827,6 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
@add_start_docstrings("""LayoutLM Model with a `language modeling` head on top. """, LAYOUTLM_START_DOCSTRING)
class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@ -850,12 +845,7 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -872,7 +862,45 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
Returns:
Examples::
>>> from transformers import LayoutLMTokenizer, LayoutLMForMaskedLM
>>> import torch
>>> tokenizer = LayoutLMTokenizer.from_pretrained('microsoft/layoutlm-base-uncased')
>>> model = LayoutLMForMaskedLM.from_pretrained('microsoft/layoutlm-base-uncased')
>>> words = ["Hello", "[MASK]"]
>>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
>>> token_boxes = []
>>> for word, box in zip(words, normalized_word_boxes):
... word_tokens = tokenizer.tokenize(word)
... token_boxes.extend([box] * len(word_tokens))
>>> # add bounding boxes of cls + sep tokens
>>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
>>> encoding = tokenizer(' '.join(words), return_tensors="pt")
>>> input_ids = encoding["input_ids"]
>>> attention_mask = encoding["attention_mask"]
>>> token_type_ids = encoding["token_type_ids"]
>>> bbox = torch.tensor([token_boxes])
>>> labels = tokenizer("Hello world", return_tensors="pt")["input_ids"]
>>> outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids,
... labels=labels)
>>> loss = outputs.loss
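        >>> # illustrative addition (not in the original example): id of the highest-scoring token at the masked position
        >>> masked_index = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
        >>> predicted_token_id = outputs.logits[0, masked_index].argmax(-1)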
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.layoutlm(
@ -915,16 +943,12 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
@add_start_docstrings(
"""
LayoutLM Model with a sequence classification head on top (a linear layer on top of the pooled output) e.g. for
document image classification tasks such as the `RVL-CDIP <https://www.cs.cmu.edu/~aharley/rvl-cdip/>`__ dataset.
""",
LAYOUTLM_START_DOCSTRING,
)
class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
@ -938,12 +962,7 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
return self.layoutlm.embeddings.word_embeddings
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -958,6 +977,162 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
Examples::
>>> from transformers import LayoutLMTokenizer, LayoutLMForSequenceClassification
>>> import torch
>>> tokenizer = LayoutLMTokenizer.from_pretrained('microsoft/layoutlm-base-uncased')
>>> model = LayoutLMForSequenceClassification.from_pretrained('microsoft/layoutlm-base-uncased')
>>> words = ["Hello", "world"]
>>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
>>> token_boxes = []
>>> for word, box in zip(words, normalized_word_boxes):
... word_tokens = tokenizer.tokenize(word)
... token_boxes.extend([box] * len(word_tokens))
>>> # add bounding boxes of cls + sep tokens
>>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
>>> encoding = tokenizer(' '.join(words), return_tensors="pt")
>>> input_ids = encoding["input_ids"]
>>> attention_mask = encoding["attention_mask"]
>>> token_type_ids = encoding["token_type_ids"]
>>> bbox = torch.tensor([token_boxes])
>>> sequence_label = torch.tensor([1])
>>> outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids,
... labels=sequence_label)
>>> loss = outputs.loss
>>> logits = outputs.logits
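        >>> # illustrative addition (not in the original example): index of the highest-scoring label
        >>> predicted_class_idx = logits.argmax(-1).item()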
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.layoutlm(
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
LayoutLM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
sequence labeling (information extraction) tasks such as the `FUNSD <https://guillaumejaume.github.io/FUNSD/>`__
dataset and the `SROIE <https://rrc.cvc.uab.es/?ch=13>`__ dataset.
""",
LAYOUTLM_START_DOCSTRING,
)
class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.layoutlm = LayoutLMModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.init_weights()
def get_input_embeddings(self):
return self.layoutlm.embeddings.word_embeddings
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
bbox=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
Returns:
Examples::
>>> from transformers import LayoutLMTokenizer, LayoutLMForTokenClassification
>>> import torch
>>> tokenizer = LayoutLMTokenizer.from_pretrained('microsoft/layoutlm-base-uncased')
>>> model = LayoutLMForTokenClassification.from_pretrained('microsoft/layoutlm-base-uncased')
>>> words = ["Hello", "world"]
>>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
>>> token_boxes = []
>>> for word, box in zip(words, normalized_word_boxes):
... word_tokens = tokenizer.tokenize(word)
... token_boxes.extend([box] * len(word_tokens))
>>> # add bounding boxes of cls + sep tokens
>>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
>>> encoding = tokenizer(' '.join(words), return_tensors="pt")
>>> input_ids = encoding["input_ids"]
>>> attention_mask = encoding["attention_mask"]
>>> token_type_ids = encoding["token_type_ids"]
>>> bbox = torch.tensor([token_boxes])
>>> token_labels = torch.tensor([1,1,0,0]).unsqueeze(0) # batch size of 1
>>> outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids,
... labels=token_labels)
>>> loss = outputs.loss
>>> logits = outputs.logits
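        >>> # illustrative addition (not in the original example): per-token predicted label indices, shape (batch_size, sequence_length)
        >>> predictions = logits.argmax(-1)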
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.layoutlm(


@ -1182,6 +1182,15 @@ class LayoutLMForMaskedLM:
requires_pytorch(self)
class LayoutLMForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class LayoutLMForTokenClassification:
def __init__(self, *args, **kwargs):
requires_pytorch(self)


@ -17,19 +17,26 @@
import unittest
from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
import torch
from transformers import (
LayoutLMConfig,
LayoutLMForMaskedLM,
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
LayoutLMModel,
)
class LayoutLMModelTester:
"""You can also import this e.g from .test_modeling_bart import BartModelTester """
"""You can also import this e.g from .test_modeling_layoutlm import LayoutLMModelTester """
def __init__(
self,
@ -150,6 +157,18 @@ class LayoutLMModelTester:
result = model(input_ids, bbox, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_for_sequence_classification(
self, config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = LayoutLMForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(
input_ids, bbox, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_for_token_classification(
self, config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
@ -185,7 +204,14 @@ class LayoutLMModelTester:
class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
LayoutLMModel,
LayoutLMForMaskedLM,
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
)
if is_torch_available()
else None
)
def setUp(self):
@ -209,36 +235,101 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
def prepare_layoutlm_batch_inputs():
# Here we prepare a batch of 2 sequences to test a LayoutLM forward pass on:
# fmt: off
input_ids = torch.tensor([[-9997.22461,-9997.22461,-9997.22461,-9997.22461,-9997.22461,-9997.22461,-9997.22461,-9997.22461,-9997.22461,-16.2628059,-10004.082,15.4330549,15.4330549,15.4330549,-9990.42,-16.3270779,-16.3270779,-16.3270779,-16.3270779,-16.3270779,-10004.8506]],device=torch_device) # noqa: E231
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],],device=torch_device) # noqa: E231
bbox = torch.tensor([[[0,0,0,0],[423,237,440,251],[427,272,441,287],[419,115,437,129],[961,885,992,912],[256,38,330,58],[256,38,330,58],[336,42,353,57],[360,39,401,56],[360,39,401,56],[411,39,471,59],[479,41,528,59],[533,39,630,60],[67,113,134,131],[141,115,209,132],[68,149,133,166],[141,149,187,164],[195,148,287,165],[195,148,287,165],[195,148,287,165],[295,148,349,165],[441,149,492,166],[497,149,546,164],[64,201,125,218],[1000,1000,1000,1000]],[[0,0,0,0],[662,150,754,166],[665,199,742,211],[519,213,554,228],[519,213,554,228],[134,433,187,454],[130,467,204,480],[130,467,204,480],[130,467,204,480],[130,467,204,480],[130,467,204,480],[314,469,376,482],[504,684,582,706],[941,825,973,900],[941,825,973,900],[941,825,973,900],[941,825,973,900],[610,749,652,765],[130,659,168,672],[176,657,237,672],[238,657,312,672],[443,653,628,672],[443,653,628,672],[716,301,825,317],[1000,1000,1000,1000]]],device=torch_device) # noqa: E231
token_type_ids = torch.tensor([[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]],device=torch_device) # noqa: E231
    # these are labels for sequence labeling (i.e. one label per token)
labels = torch.tensor([[-100,10,10,10,9,1,-100,7,7,-100,7,7,4,2,5,2,8,8,-100,-100,5,0,3,2,-100],[-100,12,12,12,-100,12,10,-100,-100,-100,-100,10,12,9,-100,-100,-100,10,10,10,9,12,-100,10,-100]],device=torch_device) # noqa: E231
# fmt: on
return input_ids, attention_mask, bbox, token_type_ids, labels
@require_torch
class LayoutLMModelIntegrationTest(unittest.TestCase):
@slow
    def test_forward_pass_no_head(self):
        model = LayoutLMModel.from_pretrained("microsoft/layoutlm-base-uncased").to(torch_device)

        input_ids, attention_mask, bbox, token_type_ids, labels = prepare_layoutlm_batch_inputs()

        # forward pass
        outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids)

        # test the sequence output on [0, :3, :3]
        expected_slice = torch.tensor(
            [[0.1785, -0.1947, -0.0425], [-0.3254, -0.2807, 0.2553], [-0.5391, -0.3322, 0.3364]],
            device=torch_device,
        )
        self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-3))

        # test the pooled output on [1, :3]
        expected_slice = torch.tensor([-0.6580, -0.0214, 0.8552], device=torch_device)
        self.assertTrue(torch.allclose(outputs.pooler_output[1, :3], expected_slice, atol=1e-3))
@slow
def test_forward_pass_sequence_classification(self):
# initialize model with randomly initialized sequence classification head
model = LayoutLMForSequenceClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=2).to(
torch_device
)
input_ids, attention_mask, bbox, token_type_ids, _ = prepare_layoutlm_batch_inputs()
# forward pass
outputs = model(
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
labels=torch.tensor([1, 1], device=torch_device),
)
# test whether we get a loss as a scalar
loss = outputs.loss
expected_shape = torch.Size([])
self.assertEqual(loss.shape, expected_shape)
# test the shape of the logits
logits = outputs.logits
expected_shape = torch.Size((2, 2))
self.assertEqual(logits.shape, expected_shape)
@slow
def test_forward_pass_token_classification(self):
# initialize model with randomly initialized token classification head
model = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=13).to(
torch_device
)
input_ids, attention_mask, bbox, token_type_ids, labels = prepare_layoutlm_batch_inputs()
# forward pass
outputs = model(
input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels
)
# test the loss calculation to be around 2.65
expected_loss = torch.tensor(2.65, device=torch_device)
self.assertTrue(torch.allclose(outputs.loss, expected_loss, atol=0.1))
# test the shape of the logits
logits = outputs.logits
expected_shape = torch.Size((2, 25, 13))
self.assertEqual(logits.shape, expected_shape)