Add LayoutLM Model (#7064)

* first version

* finish test docs readme model/config/tokenization class

* apply make style and make quality

* fix layoutlm GitHub link

* fix conflict in index.rst and add layoutlm to pretrained_models.rst

* fix bug in test_parents_and_children_in_mappings

* reformat modeling_auto.py and tokenization_auto.py

* fix bug in test_modeling_layoutlm.py

* Update docs/source/model_doc/layoutlm.rst

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update docs/source/model_doc/layoutlm.rst

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* remove inh, add tokenizer fast, and update some doc

* copy and rename necessary class from modeling_bert to modeling_layoutlm

* Update src/transformers/configuration_layoutlm.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/configuration_layoutlm.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/configuration_layoutlm.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/configuration_layoutlm.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/modeling_layoutlm.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/modeling_layoutlm.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/modeling_layoutlm.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* add mish to activations.py, import ACT2FN and import logging from utils

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Minghao Li, 2020-09-22 21:28:02 +08:00, committed by GitHub
commit cd9a0585ea (parent 244e1b5ba3)
14 changed files with 1510 additions and 3 deletions

View File: README.md

@@ -183,8 +183,9 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
24. **[MBart](https://github.com/pytorch/fairseq/tree/master/examples/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
25. **[LXMERT](https://github.com/airsplay/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
26. **[Funnel Transformer](https://github.com/laiguokun/Funnel-Transformer)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le.
27. **[LayoutLM](https://github.com/microsoft/unilm/tree/master/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
28. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users).
29. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.
These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations. You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html).

View File: docs/source/index.rst

@@ -137,7 +137,10 @@ conversion utilities for the following models:
27. `Bert For Sequence Generation <https://tfhub.dev/s?module-type=text-generation&subtype=module,placeholder>`_ (from Google) released with the paper
    `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks
    <https://arxiv.org/abs/1907.12461>`_ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
28. `LayoutLM <https://github.com/microsoft/unilm/tree/master/layoutlm>`_ (from Microsoft Research Asia) released with the paper
    `LayoutLM: Pre-training of Text and Layout for Document Image Understanding
    <https://arxiv.org/abs/1912.13318>`_ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
29. `Other community models <https://huggingface.co/models>`_, contributed by the `community
    <https://huggingface.co/users>`_.
.. toctree::
@@ -227,6 +230,7 @@ conversion utilities for the following models:
    model_doc/funnel
    model_doc/lxmert
    model_doc/bertgeneration
    model_doc/layoutlm
    internal/modeling_utils
    internal/tokenization_utils
    internal/pipelines_utils

View File: docs/source/model_doc/layoutlm.rst

@@ -0,0 +1,55 @@
LayoutLM
----------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~

The LayoutLM model was proposed in `LayoutLM: Pre-training of Text and Layout for Document Image Understanding <https://arxiv.org/abs/1912.13318>`__
by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, and Ming Zhou. It is a simple but effective pre-training method
of text and layout for document image understanding and information extraction tasks, such as form understanding and receipt understanding.

The abstract from the paper is the following:

*Pre-training techniques have been verified successfully in a variety of NLP tasks in recent years. Despite the widespread use of pre-training models for NLP applications, they almost exclusively focus on text-level manipulation, while neglecting layout and style information that is vital for document image understanding. In this paper, we propose the LayoutLM to jointly model interactions between text and layout information across scanned document images, which is beneficial for a great number of real-world document image understanding tasks such as information extraction from scanned documents. Furthermore, we also leverage image features to incorporate words' visual information into LayoutLM. To the best of our knowledge, this is the first time that text and layout are jointly learned in a single framework for document-level pre-training. It achieves new state-of-the-art results in several downstream tasks, including form understanding (from 70.72 to 79.27), receipt understanding (from 94.02 to 95.24) and document image classification (from 93.07 to 94.42).*

Tips:

- LayoutLM has an extra input called :obj:`bbox`, which holds the bounding boxes of the input tokens.
- The :obj:`bbox` coordinates must be on a 0-1000 scale, so you should normalize the bounding boxes before passing them into the model (see the sketch below).

The original code can be found `here <https://github.com/microsoft/unilm/tree/master/layoutlm>`_.
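The following minimal sketch is illustrative only and not part of the original documentation: the page size, word boxes, and normalization helper are made-up values; only the ``LayoutLMTokenizer``/``LayoutLMModel`` classes and the 0-1000 convention come from the model itself.

import torch
from transformers import LayoutLMTokenizer, LayoutLMModel

tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
model = LayoutLMModel.from_pretrained("microsoft/layoutlm-base-uncased")

words = ["Hello", "world"]
# (x0, y0, x1, y1) pixel boxes on a hypothetical 762 x 1000 pixel page
word_boxes = [[48, 84, 156, 108], [160, 84, 255, 108]]
page_width, page_height = 762, 1000

def normalize(box):
    # scale pixel coordinates to the 0-1000 range expected by LayoutLM
    x0, y0, x1, y1 = box
    return [
        int(1000 * x0 / page_width),
        int(1000 * y0 / page_height),
        int(1000 * x1 / page_width),
        int(1000 * y1 / page_height),
    ]

tokens, boxes = [], []
for word, box in zip(words, word_boxes):
    word_tokens = tokenizer.tokenize(word)
    tokens.extend(word_tokens)
    boxes.extend([normalize(box)] * len(word_tokens))  # repeat the word box for each wordpiece

# boxes conventionally used for the special [CLS]/[SEP] tokens
input_ids = tokenizer.convert_tokens_to_ids([tokenizer.cls_token] + tokens + [tokenizer.sep_token])
boxes = [[0, 0, 0, 0]] + boxes + [[1000, 1000, 1000, 1000]]

outputs = model(
    input_ids=torch.tensor([input_ids]),
    bbox=torch.tensor([boxes]),
)
sequence_output = outputs[0]  # (1, seq_len, hidden_size)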
LayoutLMConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LayoutLMConfig
    :members:


LayoutLMTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LayoutLMTokenizer
    :members:


LayoutLMModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LayoutLMModel
    :members:


LayoutLMForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LayoutLMForMaskedLM
    :members:


LayoutLMForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LayoutLMForTokenClassification
    :members:

View File: docs/source/pretrained_models.rst

@@ -408,3 +408,11 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
| | | |
| | | (see `details <https://github.com/laiguokun/Funnel-Transformer>`__) |
+--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| LayoutLM | ``microsoft/layoutlm-base-uncased`` | | 12 layers, 768-hidden, 12-heads, 113M parameters |
| | | |
| | | (see `details <https://github.com/microsoft/unilm/tree/master/layoutlm>`__) |
+ +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``microsoft/layoutlm-large-uncased`` | | 24 layers, 1024-hidden, 16-heads, 343M parameters |
| | | |
| | | (see `details <https://github.com/microsoft/unilm/tree/master/layoutlm>`__) |
+--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
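As an illustrative aside rather than part of the diff, the identifiers in the table load through the usual auto classes; the printed values should mirror the base row above:

from transformers import AutoConfig, AutoModel, AutoTokenizer

name = "microsoft/layoutlm-base-uncased"  # identifier taken from the table
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name)

config = AutoConfig.from_pretrained(name)
print(config.num_hidden_layers, config.hidden_size, config.num_attention_heads)  # 12 768 12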

View File: src/transformers/__init__.py

@@ -33,6 +33,7 @@ from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, Flau
from .configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig
from .configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
from .configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig
from .configuration_marian import MarianConfig
@@ -163,6 +164,7 @@ from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_fsmt import FSMTTokenizer
from .tokenization_funnel import FunnelTokenizer, FunnelTokenizerFast
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_layoutlm import LayoutLMTokenizer, LayoutLMTokenizerFast
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
from .tokenization_lxmert import LxmertTokenizer, LxmertTokenizerFast
from .tokenization_mbart import MBartTokenizer
@@ -363,6 +365,12 @@ if is_torch_available():
        GPT2PreTrainedModel,
        load_tf_weights_in_gpt2,
    )
    from .modeling_layoutlm import (
        LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
        LayoutLMForMaskedLM,
        LayoutLMForTokenClassification,
        LayoutLMModel,
    )
    from .modeling_longformer import (
        LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
        LongformerForMaskedLM,

View File: src/transformers/activations.py

@@ -55,3 +55,7 @@ def get_activation(activation_string):
        return ACT2FN[activation_string]
    else:
        raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))


def mish(x):
    return x * torch.tanh(torch.nn.functional.softplus(x))
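As a hedged aside (not part of the commit), a quick sanity check of the new activation, assuming it stays importable from ``transformers.activations``:

import torch
import torch.nn.functional as F

from transformers.activations import mish  # the function added above

x = torch.linspace(-3.0, 3.0, steps=7)
assert torch.allclose(mish(x), x * torch.tanh(F.softplus(x)))  # matches the closed form used above
print(mish(torch.zeros(1)))  # tensor([0.]) -- mish(0) == 0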

View File: src/transformers/configuration_auto.py

@@ -30,6 +30,7 @@ from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, Flau
from .configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig
from .configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
from .configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig
from .configuration_marian import MarianConfig
@@ -73,6 +74,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
        RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ]
    for key, value, in pretrained_map.items()
)
@@ -108,6 +110,7 @@ CONFIG_MAPPING = OrderedDict(
        ("encoder-decoder", EncoderDecoderConfig),
        ("funnel", FunnelConfig),
        ("lxmert", LxmertConfig),
        ("layoutlm", LayoutLMConfig),
    ]
)
@@ -141,6 +144,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("encoder-decoder", "Encoder decoder"),
        ("funnel", "Funnel Transformer"),
        ("lxmert", "LXMERT"),
        ("layoutlm", "LayoutLM"),
    ]
)

View File: src/transformers/configuration_layoutlm.py

@@ -0,0 +1,128 @@
# coding=utf-8
# Copyright 2010, The Microsoft Research Asia LayoutLM Team authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LayoutLM model configuration """
from .configuration_bert import BertConfig
from .utils import logging
logger = logging.get_logger(__name__)
LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"layoutlm-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/microsoft/layoutlm-base-uncased/config.json",
"layoutlm-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/microsoft/layoutlm-large-uncased/config.json",
}
class LayoutLMConfig(BertConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.LayoutLMModel`.
It is used to instantiate a LayoutLM model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the LayoutLM `layoutlm-base-uncased <https://huggingface.co/microsoft/layoutlm-base-uncased>`__ architecture.
Configuration objects inherit from :class:`~transformers.BertConfig` and can be used
to control the model outputs. Read the documentation from :class:`~transformers.BertConfig`
for more information.
Args:
vocab_size (:obj:`int`, optional, defaults to 30522):
Vocabulary size of the LayoutLM model. Defines the number of different tokens that
can be represented by the :obj:`inputs_ids` passed to the forward method of :class:`~transformers.LayoutLMModel`.
hidden_size (:obj:`int`, optional, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, optional, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, optional, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, optional, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
The non-linear activation function (function or string) in the encoder and pooler.
If string, "gelu", "relu", "swish" and "gelu_new" are supported.
hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, optional, defaults to 512):
The maximum sequence length that this model might ever be used with.
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (:obj:`int`, optional, defaults to 2):
The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
initializer_range (:obj:`float`, optional, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, optional, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
max_2d_position_embeddings (:obj:`int`, optional, defaults to 1024):
The maximum value that the 2D position embedding might ever be used with.
Typically set this to something large just in case (e.g., 1024).
Example::
>>> from transformers import LayoutLMModel, LayoutLMConfig
>>> # Initializing a LayoutLM configuration
>>> configuration = LayoutLMConfig()
>>> # Initializing a model from the configuration
>>> model = LayoutLMModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
"""
model_type = "layoutlm"
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
max_2d_position_embeddings=1024,
**kwargs
):
super().__init__(
vocab_size=vocab_size,
hidden_size=hidden_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
hidden_dropout_prob=hidden_dropout_prob,
attention_probs_dropout_prob=attention_probs_dropout_prob,
max_position_embeddings=max_position_embeddings,
type_vocab_size=type_vocab_size,
initializer_range=initializer_range,
layer_norm_eps=layer_norm_eps,
pad_token_id=pad_token_id,
gradient_checkpointing=gradient_checkpointing,
**kwargs,
)
self.max_2d_position_embeddings = max_2d_position_embeddings
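To make the one LayoutLM-specific knob concrete, here is a small illustrative sketch (not from the commit); the larger grid value is hypothetical:

from transformers import LayoutLMConfig, LayoutLMModel

config = LayoutLMConfig()
print(config.max_2d_position_embeddings)  # 1024 -- bbox coordinates fed to the model must stay below this value

# hypothetical variant admitting coordinates up to 2047 (model is randomly initialized)
custom_config = LayoutLMConfig(max_2d_position_embeddings=2048)
model = LayoutLMModel(custom_config)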

View File: src/transformers/modeling_auto.py

@@ -33,6 +33,7 @@ from .configuration_auto import (
    FSMTConfig,
    FunnelConfig,
    GPT2Config,
    LayoutLMConfig,
    LongformerConfig,
    LxmertConfig,
    MBartConfig,
@@ -124,6 +125,7 @@ from .modeling_funnel import (
    FunnelModel,
)
from .modeling_gpt2 import GPT2LMHeadModel, GPT2Model
from .modeling_layoutlm import LayoutLMForMaskedLM, LayoutLMForTokenClassification, LayoutLMModel
from .modeling_longformer import (
    LongformerForMaskedLM,
    LongformerForMultipleChoice,
@@ -206,6 +208,7 @@ MODEL_MAPPING = OrderedDict(
        (BartConfig, BartModel),
        (LongformerConfig, LongformerModel),
        (RobertaConfig, RobertaModel),
        (LayoutLMConfig, LayoutLMModel),
        (BertConfig, BertModel),
        (OpenAIGPTConfig, OpenAIGPTModel),
        (GPT2Config, GPT2Model),
@@ -226,6 +229,7 @@ MODEL_MAPPING = OrderedDict(
MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
    [
        (LayoutLMConfig, LayoutLMForMaskedLM),
        (RetriBertConfig, RetriBertModel),
        (T5Config, T5ForConditionalGeneration),
        (DistilBertConfig, DistilBertForMaskedLM),
@@ -252,6 +256,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
    [
        (LayoutLMConfig, LayoutLMForMaskedLM),
        (T5Config, T5ForConditionalGeneration),
        (DistilBertConfig, DistilBertForMaskedLM),
        (AlbertConfig, AlbertForMaskedLM),
@@ -300,6 +305,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
    [
        (LayoutLMConfig, LayoutLMForMaskedLM),
        (DistilBertConfig, DistilBertForMaskedLM),
        (AlbertConfig, AlbertForMaskedLM),
        (BartConfig, BartForConditionalGeneration),
@@ -370,6 +376,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
    [
        (LayoutLMConfig, LayoutLMForTokenClassification),
        (DistilBertConfig, DistilBertForTokenClassification),
        (CamembertConfig, CamembertForTokenClassification),
        (FlaubertConfig, FlaubertForTokenClassification),
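A hedged illustration (not part of the diff) of what these mappings enable: the auto classes can now resolve a LayoutLM config to the matching architecture without downloading any weights:

from transformers import AutoModelForTokenClassification, LayoutLMConfig

config = LayoutLMConfig(num_labels=7)  # e.g. a hypothetical 7-tag form-understanding label set
model = AutoModelForTokenClassification.from_config(config)  # randomly initialized
print(type(model).__name__)  # LayoutLMForTokenClassification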

View File: src/transformers/modeling_layoutlm.py

@@ -0,0 +1,899 @@
# coding=utf-8
# Copyright 2018 The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch LayoutLM model. """
import math
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from .activations import ACT2FN
from .configuration_layoutlm import LayoutLMConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, TokenClassifierOutput
from .modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from .utils import logging
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LayoutLMConfig"
_TOKENIZER_FOR_DOC = "LayoutLMTokenizer"
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"layoutlm-base-uncased",
"layoutlm-large-uncased",
]
LayoutLMLayerNorm = torch.nn.LayerNorm
class LayoutLMEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config):
super(LayoutLMEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
self.LayerNorm = LayoutLMLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
def forward(
self,
input_ids=None,
bbox=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
device = input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
words_embeddings = inputs_embeds
position_embeddings = self.position_embeddings(position_ids)
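# bbox holds (x0, y0, x1, y1) corner coordinates normalized to the 0-1000 range (each value must be
# smaller than config.max_2d_position_embeddings); every coordinate gets its own embedding, and
# height (y1 - y0) / width (x1 - x0) embeddings are derived from the corner differences below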
left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])
w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = (
words_embeddings
+ position_embeddings
+ left_position_embeddings
+ upper_position_embeddings
+ right_position_embeddings
+ lower_position_embeddings
+ h_position_embeddings
+ w_position_embeddings
+ token_type_embeddings
)
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class LayoutLMSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
):
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
if encoder_hidden_states is not None:
mixed_key_layer = self.key(encoder_hidden_states)
mixed_value_layer = self.value(encoder_hidden_states)
attention_mask = encoder_attention_mask
else:
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class LayoutLMSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = LayoutLMLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class LayoutLMAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.self = LayoutLMSelfAttention(config)
self.output = LayoutLMSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
):
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class LayoutLMIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class LayoutLMOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = LayoutLMLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class LayoutLMLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = LayoutLMAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
self.crossattention = LayoutLMAttention(config)
self.intermediate = LayoutLMIntermediate(config)
self.output = LayoutLMOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
):
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
if self.is_decoder and encoder_hidden_states is not None:
assert hasattr(
self, "crossattention"
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class LayoutLMEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([LayoutLMLayer(config) for _ in range(config.num_hidden_layers)])
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=False,
):
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
if getattr(self.config, "gradient_checkpointing", False):
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class LayoutLMPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class LayoutLMPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = LayoutLMLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class LayoutLMLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = LayoutLMPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class LayoutLMOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = LayoutLMLMPredictionHead(config)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class LayoutLMOnlyNSPHead(nn.Module):
def __init__(self, config):
super().__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, pooled_output):
seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score
class LayoutLMPreTrainingHeads(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = LayoutLMLMPredictionHead(config)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class LayoutLMPreTrainedModel(PreTrainedModel):
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = LayoutLMConfig
base_model_prefix = "layoutlm"
authorized_missing_keys = [r"position_ids"]
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, LayoutLMLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
LAYOUTLM_START_DOCSTRING = r""" The LayoutLM model was proposed in
`LayoutLM: Pre-training of Text and Layout for Document Image Understanding
<https://arxiv.org/abs/1912.13318>`__ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, and Ming Zhou.
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
usage and behavior.
Parameters:
config (:class:`~transformers.LayoutLMConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
LAYOUTLM_INPUTS_DOCSTRING = r"""
Inputs:
input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`transformers.LayoutLMTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.__call__` for details.
`What are input IDs? <../glossary.html#input-ids>`__
bbox (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length, 4)`, `optional`):
Bounding boxes of each input sequence token, given as ``(x0, y0, x1, y1)`` coordinates.
Selected in the range ``[0, config.max_2d_position_embeddings - 1]``.
`What are bboxes? <../glossary.html#position-ids>`_
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
`What are attention masks? <../glossary.html#attention-mask>`__
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`):
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
`What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
`What are position IDs? <../glossary.html#position-ids>`_
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
:obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_dict (:obj:`bool`, `optional`):
If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
plain tuple.
"""
@add_start_docstrings(
"The bare LayoutLM Model transformer outputting raw hidden-states without any specific head on top.",
LAYOUTLM_START_DOCSTRING,
)
class LayoutLMModel(LayoutLMPreTrainedModel):
config_class = LayoutLMConfig
pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST
base_model_prefix = "layoutlm"
def __init__(self, config):
super(LayoutLMModel, self).__init__(config)
self.config = config
self.embeddings = LayoutLMEmbeddings(config)
self.encoder = LayoutLMEncoder(config)
self.pooler = LayoutLMPooler(config)
self.init_weights()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_callable(LAYOUTLM_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="layoutlm-base-uncased",
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
bbox=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
"""
input_ids (torch.LongTensor of shape (batch_size, sequence_length)):
Indices of input sequence tokens in the vocabulary.
attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional):
Mask to avoid performing attention on padding token indices.
Mask values selected in [0, 1]: 1 for tokens that are NOT MASKED, 0 for MASKED tokens.
token_type_ids (torch.LongTensor of shape (batch_size, sequence_length), optional):
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in [0, 1]: 0 corresponds to a sentence A token, 1 corresponds to a sentence B token
position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional):
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range [0, config.max_position_embeddings - 1].
head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional):
Mask to nullify selected heads of the self-attention modules.
Mask values selected in [0, 1]: 1 indicates the head is not masked, 0 indicates the head is masked.
inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional):
Optionally, instead of passing input_ids you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert input_ids indices into associated vectors than the models internal embedding lookup matrix.
output_attentions (bool, optional):
If set to True, the attentions tensors of all attention layers are returned.
output_hidden_states (bool, optional):
If set to True, the hidden states of all layers are returned.
return_dict (bool, optional):
If set to True, the model will return a ModelOutput instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
if bbox is None:
bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
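# turn the (batch_size, seq_length) padding mask into a broadcastable (batch_size, 1, 1, seq_length)
# additive mask: positions to attend to contribute 0.0, padded positions contribute -10000.0 to the scores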
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
else:
head_mask = [None] * self.config.num_hidden_layers
embedding_output = self.embeddings(
input_ids=input_ids,
bbox=bbox,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
)
encoder_outputs = self.encoder(
embedding_output,
extended_attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@add_start_docstrings("""LayoutLM Model with a `language modeling` head on top. """, LAYOUTLM_START_DOCSTRING)
class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
config_class = LayoutLMConfig
pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST
base_model_prefix = "layoutlm"
def __init__(self, config):
super().__init__(config)
self.layoutlm = LayoutLMModel(config)
self.cls = LayoutLMOnlyMLMHead(config)
self.init_weights()
def get_input_embeddings(self):
return self.layoutlm.embeddings.word_embeddings
def get_output_embeddings(self):
return self.cls.predictions.decoder
@add_start_docstrings_to_callable(LAYOUTLM_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="layoutlm-base-uncased",
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
bbox=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.layoutlm(
input_ids,
bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(
prediction_scores.view(-1, self.config.vocab_size),
labels.view(-1),
)
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""LayoutLM Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
LAYOUTLM_START_DOCSTRING,
)
class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
config_class = LayoutLMConfig
pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST
base_model_prefix = "layoutlm"
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.layoutlm = LayoutLMModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.init_weights()
def get_input_embeddings(self):
return self.layoutlm.embeddings.word_embeddings
@add_start_docstrings_to_callable(LAYOUTLM_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="layoutlm-base-uncased",
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
bbox=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.layoutlm(
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
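# compute the loss only on positions the attention mask marks as real (non-padding) tokens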
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)[active_loss]
active_labels = labels.view(-1)[active_loss]
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File: src/transformers/tokenization_auto.py

@@ -32,6 +32,7 @@ from .configuration_auto import (
    FSMTConfig,
    FunnelConfig,
    GPT2Config,
    LayoutLMConfig,
    LongformerConfig,
    LxmertConfig,
    MarianConfig,
@@ -64,6 +65,7 @@ from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_fsmt import FSMTTokenizer
from .tokenization_funnel import FunnelTokenizer, FunnelTokenizerFast
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_layoutlm import LayoutLMTokenizer, LayoutLMTokenizerFast
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
from .tokenization_lxmert import LxmertTokenizer, LxmertTokenizerFast
from .tokenization_marian import MarianTokenizer
@@ -107,6 +109,7 @@ TOKENIZER_MAPPING = OrderedDict(
        (ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)),
        (FunnelConfig, (FunnelTokenizer, FunnelTokenizerFast)),
        (LxmertConfig, (LxmertTokenizer, LxmertTokenizerFast)),
        (LayoutLMConfig, (LayoutLMTokenizer, LayoutLMTokenizerFast)),
        (BertConfig, (BertTokenizer, BertTokenizerFast)),
        (OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
        (GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
@@ -117,6 +120,7 @@ TOKENIZER_MAPPING = OrderedDict(
        (CTRLConfig, (CTRLTokenizer, None)),
        (FSMTConfig, (FSMTTokenizer, None)),
        (BertGenerationConfig, (BertGenerationTokenizer, None)),
        (LayoutLMConfig, (LayoutLMTokenizer, None)),
    ]
)
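For clarity (not part of the diff): with the mappings above in place, the Auto classes resolve LayoutLM checkpoints on their own. A minimal sketch, assuming only the public microsoft/layoutlm-base-uncased checkpoint:

# Sketch: the entries added above let AutoConfig/AutoTokenizer pick the LayoutLM classes.
from transformers import AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained("microsoft/layoutlm-base-uncased")        # -> LayoutLMConfig
tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")  # -> LayoutLMTokenizer
fast_tok = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased", use_fast=True)  # -> LayoutLMTokenizerFast
print(type(config).__name__, type(tokenizer).__name__, type(fast_tok).__name__)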


@@ -0,0 +1,78 @@
# coding=utf-8
# Copyright 2018 The Microsoft Research Asia LayoutLM Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for model LayoutLM."""
from .tokenization_bert import BertTokenizer, BertTokenizerFast
from .utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"microsoft/layoutlm-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
"microsoft/layoutlm-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"microsoft/layoutlm-base-uncased": 512,
"microsoft/layoutlm-large-uncased": 512,
}
PRETRAINED_INIT_CONFIGURATION = {
"microsoft/layoutlm-base-uncased": {"do_lower_case": True},
"microsoft/layoutlm-large-uncased": {"do_lower_case": True},
}
class LayoutLMTokenizer(BertTokenizer):
r"""
Constructs a LayoutLM tokenizer.
:class:`~transformers.LayoutLMTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
tokenization: punctuation splitting + wordpiece.
Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
parameters.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
class LayoutLMTokenizerFast(BertTokenizerFast):
r"""
Constructs a "Fast" LayoutLMTokenizer.
:class:`~transformers.LayoutLMTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
tokenization: punctuation splitting + wordpiece.
Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
parameters.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
model_input_names = ["attention_mask"]
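A usage note (not part of the diff): because these classes simply subclass the BERT tokenizers, they tokenize text exactly like bert-base-uncased; the bbox inputs that the LayoutLM model consumes are not produced by the tokenizer and have to be built separately from OCR output. A minimal sketch:

# Sketch: LayoutLMTokenizer is plain lower-cased wordpiece tokenization; it knows nothing about layout.
from transformers import LayoutLMTokenizer

tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
print(tokenizer.tokenize("UNwant\u00E9d,running"))  # lower-cased wordpieces, punctuation split off
encoding = tokenizer("Hello world")                 # input_ids / token_type_ids / attention_mask only
# Token-level bounding boxes (`bbox`) must be built from OCR results and passed
# to the LayoutLM model alongside this encoding.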


@@ -0,0 +1,239 @@
# coding=utf-8
# Copyright 2018 The Microsoft Research Asia LayoutLM Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, require_torch_and_cuda, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
from transformers import LayoutLMConfig, LayoutLMForMaskedLM, LayoutLMForTokenClassification, LayoutLMModel
class LayoutLMModelTester:
"""You can also import this e.g from .test_modeling_bart import BartModelTester """
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
range_bbox=1000,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.range_bbox = range_bbox
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
bbox = ids_tensor([self.batch_size, self.seq_length, 4], self.range_bbox)
# Ensure that each bbox is legal, i.e. x1 >= x0 and y1 >= y0 (see the normalization sketch after this test file)
for i in range(bbox.shape[0]):
for j in range(bbox.shape[1]):
if bbox[i, j, 3] < bbox[i, j, 1]:
t = bbox[i, j, 3]
bbox[i, j, 3] = bbox[i, j, 1]
bbox[i, j, 1] = t
if bbox[i, j, 2] < bbox[i, j, 0]:
t = bbox[i, j, 2]
bbox[i, j, 2] = bbox[i, j, 0]
bbox[i, j, 0] = t
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = LayoutLMConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
return_dict=True,
)
return config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_model(
self, config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = LayoutLMModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, bbox, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, bbox, token_type_ids=token_type_ids)
result = model(input_ids, bbox)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_for_masked_lm(
self, config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = LayoutLMForMaskedLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, bbox, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_for_token_classification(
self, config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = LayoutLMForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, bbox, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
bbox,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"bbox": bbox,
"token_type_ids": token_type_ids,
"attention_mask": input_mask,
}
return config, inputs_dict
@require_torch
class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(LayoutLMModel, LayoutLMForMaskedLM, LayoutLMForTokenClassification) if is_torch_available() else ()
)
def setUp(self):
self.model_tester = LayoutLMModelTester(self)
self.config_tester = ConfigTester(self, config_class=LayoutLMConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
@cached_property
def big_model(self):
"""Cached property means this code will only be executed once."""
checkpoint_path = "microsoft/layoutlm-large-uncased"
model = LayoutLMForMaskedLM.from_pretrained(checkpoint_path).to(
torch_device
) # test whether AutoModel can determine your model_class from checkpoint name
if torch_device == "cuda":
model.half()
# optional: do more testing! This will save you time later!
return model
@slow
def test_that_LayoutLM_can_be_used_in_a_pipeline(self):
"""We can use self.big_model here without calling __init__ again."""
pass
def test_LayoutLM_loss_doesnt_change_if_you_add_padding(self):
pass
def test_LayoutLM_bad_args(self):
pass
def test_LayoutLM_backward_pass_reduces_loss(self):
"""Test loss/gradients same as reference implementation, for example."""
pass
@require_torch_and_cuda
def test_large_inputs_in_fp16_dont_cause_overflow(self):
pass
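Aside (not part of the diff): LayoutLMModelTester draws bbox coordinates with range_bbox=1000, and the swap loop in prepare_config_and_inputs only enforces LayoutLM's box convention, i.e. (x0, y0, x1, y1) with x1 >= x0 and y1 >= y0 on a 0-1000 page-normalized scale. A minimal sketch of mapping real pixel coordinates onto that scale; normalize_box and the page size are illustrative names, not part of this commit:

# Sketch only: scale pixel-space boxes to LayoutLM's 0-1000 coordinates.
def normalize_box(box, page_width, page_height):
    x0, y0, x1, y1 = box
    return [
        int(1000 * x0 / page_width),
        int(1000 * y0 / page_height),
        int(1000 * x1 / page_width),
        int(1000 * y1 / page_height),
    ]

# A word at (100, 200)-(250, 230) px on a 1224x1584 px scan:
print(normalize_box([100, 200, 250, 230], 1224, 1584))  # [81, 126, 204, 145]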


@@ -0,0 +1,68 @@
# coding=utf-8
# Copyright 2018 The Microsoft Research Asia LayoutLM Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from transformers.tokenization_layoutlm import VOCAB_FILES_NAMES, LayoutLMTokenizer
from .test_tokenization_common import TokenizerTesterMixin
class LayoutLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = LayoutLMTokenizer
def setUp(self):
super().setUp()
vocab_tokens = [
"[UNK]",
"[CLS]",
"[SEP]",
"want",
"##want",
"##ed",
"wa",
"un",
"runn",
"##ing",
",",
"low",
"lowest",
]
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
def get_tokenizer(self, **kwargs):
return LayoutLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self, tokenizer):
input_text = "UNwant\u00E9d,running"
output_text = "unwanted, running"
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = self.tokenizer_class(self.vocab_file)
tokens = tokenizer.tokenize("UNwant\u00E9d,running")
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
def test_special_tokens_as_you_expect(self):
"""If you are training a seq2seq model that expects a decoder_prefix token make sure it is prepended to decoder_input_ids """
pass