..
    Copyright 2021 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

LUKE
|
|
-----------------------------------------------------------------------------------------------------------------------
|
|
|
|
Overview
|
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
The LUKE model was proposed in `LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention
<https://arxiv.org/abs/2010.01057>`_ by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda and Yuji Matsumoto.
It is based on RoBERTa and adds entity embeddings as well as an entity-aware self-attention mechanism, which helps
improve performance on various downstream tasks involving reasoning about entities such as named entity recognition,
extractive and cloze-style question answering, entity typing, and relation classification.

The abstract from the paper is the following:

*Entity representations are useful in natural language tasks involving entities. In this paper, we propose new
pretrained contextualized representations of words and entities based on the bidirectional transformer. The proposed
model treats words and entities in a given text as independent tokens, and outputs contextualized representations of
them. Our model is trained using a new pretraining task based on the masked language model of BERT. The task involves
predicting randomly masked words and entities in a large entity-annotated corpus retrieved from Wikipedia. We also
propose an entity-aware self-attention mechanism that is an extension of the self-attention mechanism of the
transformer, and considers the types of tokens (words or entities) when computing attention scores. The proposed model
achieves impressive empirical performance on a wide range of entity-related tasks. In particular, it obtains
state-of-the-art results on five well-known datasets: Open Entity (entity typing), TACRED (relation classification),
CoNLL-2003 (named entity recognition), ReCoRD (cloze-style question answering), and SQuAD 1.1 (extractive question
answering).*

Tips:

- This implementation is the same as :class:`~transformers.RobertaModel` with the addition of entity embeddings as well
  as an entity-aware self-attention mechanism, which improves performance on tasks involving reasoning about entities.
- LUKE treats entities as input tokens; therefore, it takes :obj:`entity_ids`, :obj:`entity_attention_mask`,
  :obj:`entity_token_type_ids` and :obj:`entity_position_ids` as extra input. You can obtain those using
  :class:`~transformers.LukeTokenizer`.
- :class:`~transformers.LukeTokenizer` takes :obj:`entities` and :obj:`entity_spans` (character-based start and end
  positions of the entities in the input text) as extra input. :obj:`entities` typically consist of [MASK] entities or
  Wikipedia entities. The two kinds of input are briefly described below:

  - *Inputting [MASK] entities to compute entity representations*: The [MASK] entity is used to mask entities to be
    predicted during pretraining. When LUKE receives the [MASK] entity, it tries to predict the original entity by
    gathering the information about the entity from the input text. Therefore, the [MASK] entity can be used to address
    downstream tasks that require information about entities in text, such as entity typing, relation classification,
    and named entity recognition.
  - *Inputting Wikipedia entities to compute knowledge-enhanced token representations*: LUKE learns rich information
    (or knowledge) about Wikipedia entities during pretraining and stores the information in its entity embedding. By
    using Wikipedia entities as input tokens, LUKE outputs token representations enriched by the information stored in
    the embeddings of these entities. This is particularly effective for tasks requiring real-world knowledge, such as
    question answering.

- There are three head models for the former use case (inputting [MASK] entities):

  - :class:`~transformers.LukeForEntityClassification`, for tasks that classify a single entity in an input text, such
    as entity typing, e.g. the `Open Entity dataset <https://www.cs.utexas.edu/~eunsol/html_pages/open_entity.html>`__.
    This model places a linear head on top of the output entity representation.
  - :class:`~transformers.LukeForEntityPairClassification`, for tasks that classify the relationship between two
    entities, such as relation classification, e.g. the `TACRED dataset <https://nlp.stanford.edu/projects/tacred/>`__.
    This model places a linear head on top of the concatenated output representation of the pair of given entities.
  - :class:`~transformers.LukeForEntitySpanClassification`, for tasks that classify a sequence of entity spans, such as
    named entity recognition (NER). This model places a linear head on top of the output entity representations. You
    can address NER using this model by inputting all possible entity spans in the text to the model.

  :class:`~transformers.LukeTokenizer` has a ``task`` argument, which enables you to easily create an input to these
  head models by specifying ``task="entity_classification"``, ``task="entity_pair_classification"``, or
  ``task="entity_span_classification"``. Please refer to the example code of each head model.

  There are also 3 notebooks available, which showcase how you can reproduce the results reported in the paper with
  the HuggingFace implementation of LUKE. They can be found `here
  <https://github.com/studio-ousia/luke/tree/master/notebooks>`__.

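
The "all possible entity spans" input mentioned for :class:`~transformers.LukeForEntitySpanClassification` can be
enumerated in plain Python before calling the tokenizer. The following is a minimal sketch, not part of the library:
the helper name ``enumerate_entity_spans`` and its ``max_words`` cap are assumptions introduced for illustration, and
words are assumed to be separated by whitespace.

.. code-block:: python

    import re


    def enumerate_entity_spans(text, max_words=None):
        """Return character-based (start, end) spans for every run of consecutive words."""
        # Each match is one whitespace-delimited word with its character offsets.
        words = [(m.start(), m.end()) for m in re.finditer(r"\S+", text)]
        spans = []
        for i, (start, _) in enumerate(words):
            # Optionally cap the number of words a candidate span may cover.
            last = len(words) if max_words is None else min(len(words), i + max_words)
            for _, end in words[i:last]:
                spans.append((start, end))
        return spans


    text = "Beyoncé lives in Los Angeles"
    entity_spans = enumerate_entity_spans(text)

The resulting list could then be passed as :obj:`entity_spans` to a :class:`~transformers.LukeTokenizer` created with
``task="entity_span_classification"``. Because the number of candidates grows quadratically with the number of words, a
cap such as the hypothetical ``max_words`` may be worth applying to longer inputs.
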
Example:

.. code-block::

    >>> from transformers import LukeTokenizer, LukeModel, LukeForEntityPairClassification

    >>> model = LukeModel.from_pretrained("studio-ousia/luke-base")
    >>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")

    # Example 1: Computing the contextualized entity representation corresponding to the entity mention "Beyoncé"
    >>> text = "Beyoncé lives in Los Angeles."
    >>> entity_spans = [(0, 7)]  # character-based entity span corresponding to "Beyoncé"
    >>> inputs = tokenizer(text, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
    >>> outputs = model(**inputs)
    >>> word_last_hidden_state = outputs.last_hidden_state
    >>> entity_last_hidden_state = outputs.entity_last_hidden_state

    # Example 2: Inputting Wikipedia entities to obtain enriched contextualized representations
    >>> entities = ["Beyoncé", "Los Angeles"]  # Wikipedia entity titles corresponding to the entity mentions "Beyoncé" and "Los Angeles"
    >>> entity_spans = [(0, 7), (17, 28)]  # character-based entity spans corresponding to "Beyoncé" and "Los Angeles"
    >>> inputs = tokenizer(text, entities=entities, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
    >>> outputs = model(**inputs)
    >>> word_last_hidden_state = outputs.last_hidden_state
    >>> entity_last_hidden_state = outputs.entity_last_hidden_state

    # Example 3: Classifying the relationship between two entities using LukeForEntityPairClassification head model
    >>> model = LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-large-finetuned-tacred")
    >>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-tacred")
    >>> entity_spans = [(0, 7), (17, 28)]  # character-based entity spans corresponding to "Beyoncé" and "Los Angeles"
    >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits
    >>> predicted_class_idx = int(logits[0].argmax())
    >>> print("Predicted class:", model.config.id2label[predicted_class_idx])

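
The character offsets in the examples above are written by hand. A minimal sketch of deriving them from the mention
strings with ``str.find`` follows. ``find_entity_spans`` is a hypothetical helper, not part of the library, and it
assumes that the first occurrence of each mention in the text is the intended one.

.. code-block:: python

    def find_entity_spans(text, mentions):
        """Return character-based (start, end) spans for the first occurrence of each mention."""
        spans = []
        for mention in mentions:
            start = text.find(mention)
            if start == -1:
                raise ValueError(f"mention not found in text: {mention!r}")
            # The end offset is exclusive, matching the spans used in the examples above.
            spans.append((start, start + len(mention)))
        return spans


    text = "Beyoncé lives in Los Angeles."
    entity_spans = find_entity_spans(text, ["Beyoncé", "Los Angeles"])
    print(entity_spans)  # [(0, 7), (17, 28)]

When a mention occurs more than once in the text, the intended occurrence would have to be chosen explicitly, which
this sketch does not attempt.
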
This model was contributed by `ikuyamada <https://huggingface.co/ikuyamada>`__ and `nielsr
<https://huggingface.co/nielsr>`__. The original code can be found `here <https://github.com/studio-ousia/luke>`__.


LukeConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LukeConfig
    :members:


LukeTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LukeTokenizer
    :members: __call__, save_vocabulary


LukeModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LukeModel
    :members: forward


LukeForEntityClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LukeForEntityClassification
    :members: forward


LukeForEntityPairClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LukeForEntityPairClassification
    :members: forward


LukeForEntitySpanClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LukeForEntitySpanClassification
    :members: forward