Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-03 12:50:06 +06:00

Improve LayoutLM (#9476)

* Add LayoutLMForSequenceClassification and integration tests; improve docs; add the LayoutLM notebook to the list of community notebooks
* Make style & quality
* Address comments by @sgugger, @patrickvonplaten and @LysandreJik
* Fix rebase with master
* Reformat in one line
* Improve code examples as requested by @patrickvonplaten

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

This commit is contained in:
parent ccd1923f46
commit e45eba3b1c
@@ -13,32 +13,72 @@

LayoutLM
-----------------------------------------------------------------------------------------------------------------------

.. _Overview:

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The LayoutLM model was proposed in the paper `LayoutLM: Pre-training of Text and Layout for Document Image
Understanding <https://arxiv.org/abs/1912.13318>`__ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, and
Ming Zhou. It's a simple but effective pretraining method of text and layout for document image understanding and
information extraction tasks, such as form understanding and receipt understanding. It obtains state-of-the-art results
on several downstream tasks:

- form understanding: the `FUNSD <https://guillaumejaume.github.io/FUNSD/>`__ dataset (a collection of 199 annotated
  forms comprising more than 30,000 words).
- receipt understanding: the `SROIE <https://rrc.cvc.uab.es/?ch=13>`__ dataset (a collection of 626 receipts for
  training and 347 receipts for testing).
- document image classification: the `RVL-CDIP <https://www.cs.cmu.edu/~aharley/rvl-cdip/>`__ dataset (a collection of
  400,000 images belonging to one of 16 classes).

The abstract from the paper is the following:

*Pre-training techniques have been verified successfully in a variety of NLP tasks in recent years. Despite the
widespread use of pretraining models for NLP applications, they almost exclusively focus on text-level manipulation,
while neglecting layout and style information that is vital for document image understanding. In this paper, we propose
the LayoutLM to jointly model interactions between text and layout information across scanned document images, which is
beneficial for a great number of real-world document image understanding tasks such as information extraction from
scanned documents. Furthermore, we also leverage image features to incorporate words' visual information into LayoutLM.
To the best of our knowledge, this is the first time that text and layout are jointly learned in a single framework for
document-level pretraining. It achieves new state-of-the-art results in several downstream tasks, including form
understanding (from 70.72 to 79.27), receipt understanding (from 94.02 to 95.24) and document image classification
(from 93.07 to 94.42).*

Tips:

- In addition to :obj:`input_ids`, :meth:`~transformers.LayoutLMModel.forward` also expects the input :obj:`bbox`,
  which are the bounding boxes (i.e. 2D-positions) of the input tokens. These can be obtained using an external OCR
  engine such as Google's `Tesseract <https://github.com/tesseract-ocr/tesseract>`__ (there's a `Python wrapper
  <https://pypi.org/project/pytesseract/>`__ available). Each bounding box should be in (x0, y0, x1, y1) format, where
  (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1, y1) represents the
  position of the lower right corner. Note that one first needs to normalize the bounding boxes to be on a 0-1000
  scale. To normalize, you can use the following function:

.. code-block::

    def normalize_bbox(bbox, width, height):
        return [
            int(1000 * (bbox[0] / width)),
            int(1000 * (bbox[1] / height)),
            int(1000 * (bbox[2] / width)),
            int(1000 * (bbox[3] / height)),
        ]

Here, :obj:`width` and :obj:`height` correspond to the width and height of the original document in which the token
occurs. Those can be obtained using the Python Imaging Library (PIL), for example, as follows:

.. code-block::

    from PIL import Image

    image = Image.open("name_of_your_document - can be a png file, pdf, etc.")

    width, height = image.size
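To illustrate the full pre-processing flow, here is a minimal sketch (not taken from the official examples; the
variable names are illustrative and pytesseract is assumed to be installed) that uses the Python wrapper mentioned
above to obtain words and unnormalized boxes from an image, and then normalizes them with the :obj:`normalize_bbox`
function defined above:

.. code-block::

    import pytesseract
    from PIL import Image

    image = Image.open("name_of_your_document.png").convert("RGB")
    width, height = image.size

    # pytesseract returns a dict containing, among others, the recognized text and the
    # left/top/width/height coordinates of each detected box
    ocr = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)

    words, unnormalized_boxes = [], []
    for text, left, top, w, h in zip(ocr["text"], ocr["left"], ocr["top"], ocr["width"], ocr["height"]):
        if text.strip():  # skip empty detections
            words.append(text)
            unnormalized_boxes.append([left, top, left + w, top + h])  # (x0, y0, x1, y1)

    normalized_boxes = [normalize_bbox(box, width, height) for box in unnormalized_boxes]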
- For a demo which shows how to fine-tune :class:`LayoutLMForTokenClassification` on the `FUNSD dataset
  <https://guillaumejaume.github.io/FUNSD/>`__ (a collection of annotated forms), see `this notebook
  <https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LayoutLM/Fine_tuning_LayoutLMForTokenClassification_on_FUNSD.ipynb>`__.
  It includes an inference part, which shows how to use Google's Tesseract on a new document.

The original code can be found `here <https://github.com/microsoft/unilm/tree/master/layoutlm>`_.
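Putting these pieces together, a minimal fine-tuning step with the newly added
:class:`~transformers.LayoutLMForSequenceClassification` could look like the following sketch (the token boxes and the
label are dummy values; a real project would of course iterate over a proper dataset):

.. code-block::

    import torch
    from transformers import LayoutLMTokenizer, LayoutLMForSequenceClassification

    tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
    model = LayoutLMForSequenceClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=2)

    encoding = tokenizer("Hello world", return_tensors="pt")
    # one (already normalized) box per token, including [CLS] and [SEP]
    bbox = torch.tensor([[[0, 0, 0, 0], [637, 773, 693, 782], [698, 773, 733, 782], [1000, 1000, 1000, 1000]]])
    labels = torch.tensor([1])  # dummy document class

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

    outputs = model(input_ids=encoding["input_ids"], bbox=bbox, attention_mask=encoding["attention_mask"],
                    token_type_ids=encoding["token_type_ids"], labels=labels)
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()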
@@ -78,6 +118,13 @@ LayoutLMForMaskedLM
    :members:


LayoutLMForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LayoutLMForSequenceClassification
    :members:


LayoutLMForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -76,4 +76,5 @@ Pull Request so it can be included under the Community notebooks.

|[Fine-tuning TAPAS on Sequential Question Answering (SQA)](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb) | How to fine-tune *TapasForQuestionAnswering* with a *tapas-base* checkpoint on the Sequential Question Answering (SQA) dataset | [Niels Rogge](https://github.com/nielsrogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb)|
|[Evaluating TAPAS on Table Fact Checking (TabFact)](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Evaluating_TAPAS_on_the_Tabfact_test_set.ipynb) | How to evaluate a fine-tuned *TapasForSequenceClassification* with a *tapas-base-finetuned-tabfact* checkpoint using a combination of the 🤗 datasets and 🤗 transformers libraries | [Niels Rogge](https://github.com/nielsrogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Evaluating_TAPAS_on_the_Tabfact_test_set.ipynb)|
|[Fine-tuning mBART for translation](https://colab.research.google.com/github/vasudevgupta7/huggingface-tutorials/blob/main/translation_training.ipynb) | How to fine-tune mBART using Seq2SeqTrainer for Hindi to English translation | [Vasudev Gupta](https://github.com/vasudevgupta7) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vasudevgupta7/huggingface-tutorials/blob/main/translation_training.ipynb)|
|[Fine-tuning LayoutLM on FUNSD (a form understanding dataset)](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LayoutLM/Fine_tuning_LayoutLMForTokenClassification_on_FUNSD.ipynb) | How to fine-tune *LayoutLMForTokenClassification* on the FUNSD dataset | [Niels Rogge](https://github.com/nielsrogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/LayoutLM/Fine_tuning_LayoutLMForTokenClassification_on_FUNSD.ipynb)|
|[Fine-Tune DistilGPT2 and Generate Text](https://colab.research.google.com/github/tripathiaakash/DistilGPT2-Tutorial/blob/main/distilgpt2_fine_tuning.ipynb) | How to fine-tune DistilGPT2 and generate text | [Aakash Tripathi](https://github.com/tripathiaakash) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tripathiaakash/DistilGPT2-Tutorial/blob/main/distilgpt2_fine_tuning.ipynb)|

@@ -561,6 +561,7 @@ if is_torch_available():
        [
            "LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST",
            "LayoutLMForMaskedLM",
            "LayoutLMForSequenceClassification",
            "LayoutLMForTokenClassification",
            "LayoutLMModel",
        ]

@@ -1597,6 +1598,7 @@ if TYPE_CHECKING:
        from .models.layoutlm import (
            LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
            LayoutLMForMaskedLM,
            LayoutLMForSequenceClassification,
            LayoutLMForTokenClassification,
            LayoutLMModel,
        )

@@ -101,7 +101,12 @@ from ..funnel.modeling_funnel import (
    FunnelModel,
)
from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model
from ..layoutlm.modeling_layoutlm import (
    LayoutLMForMaskedLM,
    LayoutLMForSequenceClassification,
    LayoutLMForTokenClassification,
    LayoutLMModel,
)
from ..led.modeling_led import (
    LEDForConditionalGeneration,
    LEDForQuestionAnswering,
@@ -470,6 +475,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
        (TransfoXLConfig, TransfoXLForSequenceClassification),
        (MPNetConfig, MPNetForSequenceClassification),
        (TapasConfig, TapasForSequenceClassification),
        (LayoutLMConfig, LayoutLMForSequenceClassification),
    ]
)
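With ``LayoutLMConfig`` registered in ``MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING``, the auto classes should now be
able to dispatch a LayoutLM checkpoint to the new head. A small, hypothetical sanity check (not part of the diff
itself):

.. code-block::

    from transformers import AutoModelForSequenceClassification

    # the base checkpoint has no classification head, so the head weights are randomly initialized
    model = AutoModelForSequenceClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=2)
    print(type(model).__name__)  # expected: LayoutLMForSequenceClassification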
@@ -33,6 +33,7 @@ if is_torch_available():
    _import_structure["modeling_layoutlm"] = [
        "LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST",
        "LayoutLMForMaskedLM",
        "LayoutLMForSequenceClassification",
        "LayoutLMForTokenClassification",
        "LayoutLMModel",
    ]

@@ -49,6 +50,7 @@ if TYPE_CHECKING:
    from .modeling_layoutlm import (
        LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
        LayoutLMForMaskedLM,
        LayoutLMForSequenceClassification,
        LayoutLMForTokenClassification,
        LayoutLMModel,
    )

@@ -19,14 +19,15 @@ import math

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    MaskedLMOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import (
@@ -596,6 +597,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
    """

    config_class = LayoutLMConfig
    pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST
    base_model_prefix = "layoutlm"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

@@ -614,7 +616,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel):

LAYOUTLM_START_DOCSTRING = r"""
    The LayoutLM model was proposed in `LayoutLM: Pre-training of Text and Layout for Document Image Understanding
    <https://arxiv.org/abs/1912.13318>`__ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei and Ming Zhou.

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
@@ -638,8 +640,10 @@ LAYOUTLM_INPUTS_DOCSTRING = r"""

            `What are input IDs? <../glossary.html#input-ids>`__
        bbox (:obj:`torch.LongTensor` of shape :obj:`({0}, 4)`, `optional`):
            Bounding boxes of each input sequence token. Selected in the range ``[0,
            config.max_2d_position_embeddings-1]``. Each bounding box should be a normalized version in (x0, y0, x1,
            y1) format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and
            (x1, y1) represents the position of the lower right corner. See :ref:`Overview` for normalization.
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: ``1`` for
            tokens that are NOT MASKED, ``0`` for MASKED tokens.
@@ -679,11 +683,6 @@ LAYOUTLM_INPUTS_DOCSTRING = r"""
    LAYOUTLM_START_DOCSTRING,
)
class LayoutLMModel(LayoutLMPreTrainedModel):
    def __init__(self, config):
        super(LayoutLMModel, self).__init__(config)
        self.config = config
@@ -709,12 +708,7 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
@@ -730,31 +724,36 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Returns:

        Examples::

            >>> from transformers import LayoutLMTokenizer, LayoutLMModel
            >>> import torch

            >>> tokenizer = LayoutLMTokenizer.from_pretrained('microsoft/layoutlm-base-uncased')
            >>> model = LayoutLMModel.from_pretrained('microsoft/layoutlm-base-uncased')

            >>> words = ["Hello", "world"]
            >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]

            >>> token_boxes = []
            >>> for word, box in zip(words, normalized_word_boxes):
            ...     word_tokens = tokenizer.tokenize(word)
            ...     token_boxes.extend([box] * len(word_tokens))
            >>> # add bounding boxes of cls + sep tokens
            >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]

            >>> encoding = tokenizer(' '.join(words), return_tensors="pt")
            >>> input_ids = encoding["input_ids"]
            >>> attention_mask = encoding["attention_mask"]
            >>> token_type_ids = encoding["token_type_ids"]
            >>> bbox = torch.tensor([token_boxes])

            >>> outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids)

            >>> last_hidden_states = outputs.last_hidden_state
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (

@@ -828,10 +827,6 @@ class LayoutLMModel(LayoutLMPreTrainedModel):

@add_start_docstrings("""LayoutLM Model with a `language modeling` head on top. """, LAYOUTLM_START_DOCSTRING)
class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

@@ -850,12 +845,7 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
        self.cls.predictions.decoder = new_embeddings

    @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
@@ -872,7 +862,45 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``

        Returns:

        Examples::

            >>> from transformers import LayoutLMTokenizer, LayoutLMForMaskedLM
            >>> import torch

            >>> tokenizer = LayoutLMTokenizer.from_pretrained('microsoft/layoutlm-base-uncased')
            >>> model = LayoutLMForMaskedLM.from_pretrained('microsoft/layoutlm-base-uncased')

            >>> words = ["Hello", "[MASK]"]
            >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]

            >>> token_boxes = []
            >>> for word, box in zip(words, normalized_word_boxes):
            ...     word_tokens = tokenizer.tokenize(word)
            ...     token_boxes.extend([box] * len(word_tokens))
            >>> # add bounding boxes of cls + sep tokens
            >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]

            >>> encoding = tokenizer(' '.join(words), return_tensors="pt")
            >>> input_ids = encoding["input_ids"]
            >>> attention_mask = encoding["attention_mask"]
            >>> token_type_ids = encoding["token_type_ids"]
            >>> bbox = torch.tensor([token_boxes])

            >>> labels = tokenizer("Hello world", return_tensors="pt")["input_ids"]

            >>> outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids,
            ...                 labels=labels)

            >>> loss = outputs.loss
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.layoutlm(
@@ -915,16 +943,12 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):

@add_start_docstrings(
    """
    LayoutLM Model with a sequence classification head on top (a linear layer on top of the pooled output) e.g. for
    document image classification tasks such as the `RVL-CDIP <https://www.cs.cmu.edu/~aharley/rvl-cdip/>`__ dataset.
    """,
    LAYOUTLM_START_DOCSTRING,
)
class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
@@ -938,12 +962,7 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
        return self.layoutlm.embeddings.word_embeddings

    @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
@@ -958,6 +977,162 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Examples::

            >>> from transformers import LayoutLMTokenizer, LayoutLMForSequenceClassification
            >>> import torch

            >>> tokenizer = LayoutLMTokenizer.from_pretrained('microsoft/layoutlm-base-uncased')
            >>> model = LayoutLMForSequenceClassification.from_pretrained('microsoft/layoutlm-base-uncased')

            >>> words = ["Hello", "world"]
            >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]

            >>> token_boxes = []
            >>> for word, box in zip(words, normalized_word_boxes):
            ...     word_tokens = tokenizer.tokenize(word)
            ...     token_boxes.extend([box] * len(word_tokens))
            >>> # add bounding boxes of cls + sep tokens
            >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]

            >>> encoding = tokenizer(' '.join(words), return_tensors="pt")
            >>> input_ids = encoding["input_ids"]
            >>> attention_mask = encoding["attention_mask"]
            >>> token_type_ids = encoding["token_type_ids"]
            >>> bbox = torch.tensor([token_boxes])
            >>> sequence_label = torch.tensor([1])

            >>> outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids,
            ...                 labels=sequence_label)

            >>> loss = outputs.loss
            >>> logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.layoutlm(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    LayoutLM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    sequence labeling (information extraction) tasks such as the `FUNSD <https://guillaumejaume.github.io/FUNSD/>`__
    dataset and the `SROIE <https://rrc.cvc.uab.es/?ch=13>`__ dataset.
    """,
    LAYOUTLM_START_DOCSTRING,
)
class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.layoutlm = LayoutLMModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def get_input_embeddings(self):
        return self.layoutlm.embeddings.word_embeddings

    @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        bbox=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
            1]``.

        Returns:

        Examples::

            >>> from transformers import LayoutLMTokenizer, LayoutLMForTokenClassification
            >>> import torch

            >>> tokenizer = LayoutLMTokenizer.from_pretrained('microsoft/layoutlm-base-uncased')
            >>> model = LayoutLMForTokenClassification.from_pretrained('microsoft/layoutlm-base-uncased')

            >>> words = ["Hello", "world"]
            >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]

            >>> token_boxes = []
            >>> for word, box in zip(words, normalized_word_boxes):
            ...     word_tokens = tokenizer.tokenize(word)
            ...     token_boxes.extend([box] * len(word_tokens))
            >>> # add bounding boxes of cls + sep tokens
            >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]

            >>> encoding = tokenizer(' '.join(words), return_tensors="pt")
            >>> input_ids = encoding["input_ids"]
            >>> attention_mask = encoding["attention_mask"]
            >>> token_type_ids = encoding["token_type_ids"]
            >>> bbox = torch.tensor([token_boxes])
            >>> token_labels = torch.tensor([1, 1, 0, 0]).unsqueeze(0)  # batch size of 1

            >>> outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids,
            ...                 labels=token_labels)

            >>> loss = outputs.loss
            >>> logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.layoutlm(

@@ -1182,6 +1182,15 @@ class LayoutLMForMaskedLM:
        requires_pytorch(self)


class LayoutLMForSequenceClassification:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_pytorch(self)


class LayoutLMForTokenClassification:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)

@@ -17,19 +17,26 @@
import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor


if is_torch_available():
    import torch

    from transformers import (
        LayoutLMConfig,
        LayoutLMForMaskedLM,
        LayoutLMForSequenceClassification,
        LayoutLMForTokenClassification,
        LayoutLMModel,
    )


class LayoutLMModelTester:
    """You can also import this e.g. from .test_modeling_layoutlm import LayoutLMModelTester"""

    def __init__(
        self,
@@ -150,6 +157,18 @@ class LayoutLMModelTester:
        result = model(input_ids, bbox, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_for_sequence_classification(
        self, config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = LayoutLMForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids, bbox, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_for_token_classification(
        self, config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
@@ -185,7 +204,14 @@ class LayoutLMModelTester:
class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):

    all_model_classes = (
        (
            LayoutLMModel,
            LayoutLMForMaskedLM,
            LayoutLMForSequenceClassification,
            LayoutLMForTokenClassification,
        )
        if is_torch_available()
        else None
    )

    def setUp(self):
@@ -209,36 +235,101 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)


def prepare_layoutlm_batch_inputs():
    # Here we prepare a batch of 2 sequences to test a LayoutLM forward pass on:
    # fmt: off
    input_ids = torch.tensor([[-9997.22461,-9997.22461,-9997.22461,-9997.22461,-9997.22461,-9997.22461,-9997.22461,-9997.22461,-9997.22461,-16.2628059,-10004.082,15.4330549,15.4330549,15.4330549,-9990.42,-16.3270779,-16.3270779,-16.3270779,-16.3270779,-16.3270779,-10004.8506]],device=torch_device) # noqa: E231
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],],device=torch_device) # noqa: E231
    bbox = torch.tensor([[[0,0,0,0],[423,237,440,251],[427,272,441,287],[419,115,437,129],[961,885,992,912],[256,38,330,58],[256,38,330,58],[336,42,353,57],[360,39,401,56],[360,39,401,56],[411,39,471,59],[479,41,528,59],[533,39,630,60],[67,113,134,131],[141,115,209,132],[68,149,133,166],[141,149,187,164],[195,148,287,165],[195,148,287,165],[195,148,287,165],[295,148,349,165],[441,149,492,166],[497,149,546,164],[64,201,125,218],[1000,1000,1000,1000]],[[0,0,0,0],[662,150,754,166],[665,199,742,211],[519,213,554,228],[519,213,554,228],[134,433,187,454],[130,467,204,480],[130,467,204,480],[130,467,204,480],[130,467,204,480],[130,467,204,480],[314,469,376,482],[504,684,582,706],[941,825,973,900],[941,825,973,900],[941,825,973,900],[941,825,973,900],[610,749,652,765],[130,659,168,672],[176,657,237,672],[238,657,312,672],[443,653,628,672],[443,653,628,672],[716,301,825,317],[1000,1000,1000,1000]]],device=torch_device) # noqa: E231
    token_type_ids = torch.tensor([[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]],device=torch_device) # noqa: E231
    # these are sequence labels (i.e. at the token level)
    labels = torch.tensor([[-100,10,10,10,9,1,-100,7,7,-100,7,7,4,2,5,2,8,8,-100,-100,5,0,3,2,-100],[-100,12,12,12,-100,12,10,-100,-100,-100,-100,10,12,9,-100,-100,-100,10,10,10,9,12,-100,10,-100]],device=torch_device) # noqa: E231
    # fmt: on

    return input_ids, attention_mask, bbox, token_type_ids, labels


@require_torch
class LayoutLMModelIntegrationTest(unittest.TestCase):
    @slow
    def test_forward_pass_no_head(self):
        model = LayoutLMModel.from_pretrained("microsoft/layoutlm-base-uncased").to(torch_device)

        input_ids, attention_mask, bbox, token_type_ids, labels = prepare_layoutlm_batch_inputs()

        # forward pass
        outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids)

        # test the sequence output on [0, :3, :3]
        expected_slice = torch.tensor(
            [[0.1785, -0.1947, -0.0425], [-0.3254, -0.2807, 0.2553], [-0.5391, -0.3322, 0.3364]],
            device=torch_device,
        )

        self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-3))

        # test the pooled output on [1, :3]
        expected_slice = torch.tensor([-0.6580, -0.0214, 0.8552], device=torch_device)

        self.assertTrue(torch.allclose(outputs.pooler_output[1, :3], expected_slice, atol=1e-3))

    @slow
    def test_forward_pass_sequence_classification(self):
        # initialize model with randomly initialized sequence classification head
        model = LayoutLMForSequenceClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=2).to(
            torch_device
        )

        input_ids, attention_mask, bbox, token_type_ids, _ = prepare_layoutlm_batch_inputs()

        # forward pass
        outputs = model(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=torch.tensor([1, 1], device=torch_device),
        )

        # test whether we get a loss as a scalar
        loss = outputs.loss
        expected_shape = torch.Size([])
        self.assertEqual(loss.shape, expected_shape)

        # test the shape of the logits
        logits = outputs.logits
        expected_shape = torch.Size((2, 2))
        self.assertEqual(logits.shape, expected_shape)

    @slow
    def test_forward_pass_token_classification(self):
        # initialize model with randomly initialized token classification head
        model = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=13).to(
            torch_device
        )

        input_ids, attention_mask, bbox, token_type_ids, labels = prepare_layoutlm_batch_inputs()

        # forward pass
        outputs = model(
            input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels
        )

        # test the loss calculation to be around 2.65
        expected_loss = torch.tensor(2.65, device=torch_device)

        self.assertTrue(torch.allclose(outputs.loss, expected_loss, atol=0.1))

        # test the shape of the logits
        logits = outputs.logits
        expected_shape = torch.Size((2, 25, 13))
        self.assertEqual(logits.shape, expected_shape)