transformers/docs/source/tasks/token_classification.mdx
Steven Liu 7732148124
Adopt framework-specific blocks for content (#16342)
*  refactor code samples with framework-specific blocks

*  update training.mdx

* 🖍 apply feedback
2022-03-22 16:14:58 -05:00


<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Token classification
<Youtube id="wVHdVlPScxA"/>
Token classification assigns a label to individual tokens in a sentence. One of the most common token classification tasks is Named Entity Recognition (NER). NER attempts to find a label for each entity in a sentence, such as a person, location, or organization.
This guide will show you how to fine-tune [DistilBERT](https://huggingface.co/distilbert-base-uncased) on the [WNUT 17](https://huggingface.co/datasets/wnut_17) dataset to detect new entities.
<Tip>
See the token classification [task page](https://huggingface.co/tasks/token-classification) for more information about other forms of token classification and their associated models, datasets, and metrics.
</Tip>
## Load WNUT 17 dataset
Load the WNUT 17 dataset from the 🤗 Datasets library:
```py
>>> from datasets import load_dataset
>>> wnut = load_dataset("wnut_17")
```
Then take a look at an example:
```py
>>> wnut["train"][0]
{'id': '0',
'ner_tags': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 8, 8, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0],
'tokens': ['@paulwalk', 'It', "'s", 'the', 'view', 'from', 'where', 'I', "'m", 'living', 'for', 'two', 'weeks', '.', 'Empire', 'State', 'Building', '=', 'ESB', '.', 'Pretty', 'bad', 'storm', 'here', 'last', 'evening', '.']
}
```
Each number in `ner_tags` represents an entity. Convert the number to a label name for more information:
```py
>>> label_list = wnut["train"].features["ner_tags"].feature.names
>>> label_list
[
"O",
"B-corporation",
"I-corporation",
"B-creative-work",
"I-creative-work",
"B-group",
"I-group",
"B-location",
"I-location",
"B-person",
"I-person",
"B-product",
"I-product",
]
```
Each label in `ner_tags` describes an entity type, such as a corporation, location, or person. The prefix of each label indicates the token's position within the entity:
- `B-` indicates the beginning of an entity.
- `I-` indicates a token is contained inside the same entity (e.g., the `State` token is a part of an entity like
`Empire State Building`).
- `O` indicates the token doesn't correspond to any entity.
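As a rough illustration of how these prefixes combine, the sketch below groups tagged words back into entity spans. The helper name `group_entities` is hypothetical, and the word and tag lists are a hand-copied slice of the example above, so nothing here needs the dataset to be loaded.

```python
label_list = [
    "O",
    "B-corporation", "I-corporation",
    "B-creative-work", "I-creative-work",
    "B-group", "I-group",
    "B-location", "I-location",
    "B-person", "I-person",
    "B-product", "I-product",
]

# A slice of the first training example, copied by hand.
words = ["Empire", "State", "Building", "=", "ESB", "."]
tag_ids = [7, 8, 8, 0, 7, 0]


def group_entities(words, tag_ids, label_list):
    """Collect (entity_type, text) pairs from B-/I-/O tags."""
    entities = []
    current = None
    for word, tag_id in zip(words, tag_ids):
        label = label_list[tag_id]
        if label.startswith("B-"):
            # A B- tag always opens a new entity.
            current = {"type": label[2:], "words": [word]}
            entities.append(current)
        elif label.startswith("I-") and current is not None and current["type"] == label[2:]:
            # An I- tag extends the entity that is currently open.
            current["words"].append(word)
        else:
            # "O" (or a stray I- tag with no matching open entity) closes the span.
            current = None
    return [(entity["type"], " ".join(entity["words"])) for entity in entities]


entities = group_entities(words, tag_ids, label_list)
print(entities)
```

Here the tags `7, 8, 8` form one `location` span, and the later `7` opens a second one.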
## Preprocess
<Youtube id="iY2AZYdZAr0"/>
Load the DistilBERT tokenizer to process the `tokens`:
```py
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
```
Since the input has already been split into words, set `is_split_into_words=True` to tokenize the words into subwords:
```py
>>> example = wnut["train"][0]
>>> tokenized_input = tokenizer(example["tokens"], is_split_into_words=True)
>>> tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
>>> tokens
['[CLS]', '@', 'paul', '##walk', 'it', "'", 's', 'the', 'view', 'from', 'where', 'i', "'", 'm', 'living', 'for', 'two', 'weeks', '.', 'empire', 'state', 'building', '=', 'es', '##b', '.', 'pretty', 'bad', 'storm', 'here', 'last', 'evening', '.', '[SEP]']
```
Adding the special tokens `[CLS]` and `[SEP]` and subword tokenization creates a mismatch between the input and labels. A single word corresponding to a single label may now be split into several subwords (for example, `@paulwalk` becomes `@`, `paul`, and `##walk`). You will need to realign the tokens and labels by:
1. Mapping all tokens to their corresponding word with the [`word_ids`](https://huggingface.co/docs/tokenizers/python/latest/api/reference.html#tokenizers.Encoding.word_ids) method.
2. Assigning the label `-100` to the special tokens `[CLS]` and `[SEP]` so the PyTorch loss function ignores
them.
3. Only labeling the first token of a given word. Assign `-100` to other subtokens from the same word.
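To see why `-100` works as an ignore marker, it can help to look at a loss that skips those positions. The sketch below is a plain-Python stand-in, not PyTorch's implementation: `masked_cross_entropy` and the toy logits are hypothetical, and it assumes the PyTorch cross-entropy default of `ignore_index=-100`.

```python
import math

IGNORE_INDEX = -100  # assumed to match the PyTorch cross-entropy default


def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean cross-entropy over positions whose label is not ignore_index."""
    total, count = 0.0, 0
    for token_logits, label in zip(logits, labels):
        if label == ignore_index:
            continue  # special tokens and non-first subwords contribute nothing
        top = max(token_logits)
        # Numerically stable log-sum-exp over the class scores.
        log_norm = top + math.log(sum(math.exp(x - top) for x in token_logits))
        total += log_norm - token_logits[label]
        count += 1
    # Simplification for this sketch: return 0.0 when every position is ignored.
    return total / count if count else 0.0


# Three token positions, three classes; only the middle position has a real label.
logits = [[9.0, -3.0, 0.5], [0.2, 1.5, -0.7], [-4.0, 8.0, 2.0]]
labels = [IGNORE_INDEX, 1, IGNORE_INDEX]

loss = masked_cross_entropy(logits, labels)
reference = masked_cross_entropy([logits[1]], [1])
print(loss == reference)
```

Whatever the model predicts at the ignored positions, the loss depends only on the labeled token.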
Here is how you can create a function to realign the tokens and labels, and truncate sequences to be no longer than DistilBERT's maximum input length:
```py
>>> def tokenize_and_align_labels(examples):
...     tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
...     labels = []
...     for i, label in enumerate(examples["ner_tags"]):
...         word_ids = tokenized_inputs.word_ids(batch_index=i)  # Map tokens to their respective word.
...         previous_word_idx = None
...         label_ids = []
...         for word_idx in word_ids:
...             if word_idx is None:  # Set the special tokens to -100.
...                 label_ids.append(-100)
...             elif word_idx != previous_word_idx:  # Only label the first token of a given word.
...                 label_ids.append(label[word_idx])
...             else:  # Set the remaining subtokens of the same word to -100.
...                 label_ids.append(-100)
...             previous_word_idx = word_idx
...         labels.append(label_ids)
...     tokenized_inputs["labels"] = labels
...     return tokenized_inputs
```
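The inner loop of the function above can be traced on a small case without a tokenizer. In this sketch the `word_ids` list is written by hand to mimic what a tokenizer could return for a six-word slice of the example (with `ESB` split into `es` and `##b`). The helper name `align_labels` is hypothetical.

```python
def align_labels(word_ids, word_labels, ignore_index=-100):
    """Expand word-level labels to token-level labels."""
    aligned = []
    previous_word_idx = None
    for word_idx in word_ids:
        if word_idx is None:
            # Special tokens such as [CLS] and [SEP] have no source word.
            aligned.append(ignore_index)
        elif word_idx != previous_word_idx:
            # First subtoken of a word keeps the word's label.
            aligned.append(word_labels[word_idx])
        else:
            # Later subtokens of the same word are ignored.
            aligned.append(ignore_index)
        previous_word_idx = word_idx
    return aligned


# Words:  Empire State Building = ESB .
word_labels = [7, 8, 8, 0, 7, 0]
# Tokens: [CLS] empire state building = es ##b . [SEP]
word_ids = [None, 0, 1, 2, 3, 4, 4, 5, None]

aligned = align_labels(word_ids, word_labels)
print(aligned)
```

The result has one entry per token: the two special tokens and `##b` receive `-100`, and every other token keeps its word's label.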
Use the 🤗 Datasets [`map`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map) function to tokenize and align the labels over the entire dataset. You can speed up the `map` function by setting `batched=True` to process multiple elements of the dataset at once:
```py
>>> tokenized_wnut = wnut.map(tokenize_and_align_labels, batched=True)
```
Use [`DataCollatorForTokenClassification`] to create a batch of examples. It will also *dynamically pad* your text and labels to the length of the longest element in its batch, so they are a uniform length. While it is possible to pad your text in the `tokenizer` function by setting `padding=True`, dynamic padding is more efficient.
<frameworkcontent>
<pt>
```py
>>> from transformers import DataCollatorForTokenClassification
>>> data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
```
</pt>
<tf>
```py
>>> from transformers import DataCollatorForTokenClassification
>>> data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, return_tensors="tf")
```
</tf>
</frameworkcontent>
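To make dynamic padding concrete, here is a minimal plain-Python sketch of the idea. It is not the collator's actual implementation. The function `pad_batch`, the token ids, and the choice of `0` as the padding id are assumptions for illustration; labels are padded with `-100` so the padded positions are ignored by the loss.

```python
def pad_batch(features, pad_token_id=0, label_pad_id=-100):
    """Pad every example to the length of the longest example in this batch."""
    max_len = max(len(feature["input_ids"]) for feature in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for feature in features:
        length = len(feature["input_ids"])
        n_pad = max_len - length
        batch["input_ids"].append(feature["input_ids"] + [pad_token_id] * n_pad)
        # 1 marks real tokens, 0 marks padding.
        batch["attention_mask"].append([1] * length + [0] * n_pad)
        batch["labels"].append(feature["labels"] + [label_pad_id] * n_pad)
    return batch


# Two already-tokenized examples of different lengths (ids are illustrative).
features = [
    {"input_ids": [101, 7592, 102], "labels": [-100, 0, -100]},
    {"input_ids": [101, 2047, 2259, 2103, 102], "labels": [-100, 7, 8, 0, -100]},
]

batch = pad_batch(features)
print(batch["input_ids"])
print(batch["labels"])
```

Because the target length is computed per batch, a batch of short sentences is not padded out to the longest sentence in the whole dataset.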
## Train
<frameworkcontent>
<pt>
Load DistilBERT with [`AutoModelForTokenClassification`] along with the number of expected labels:
```py
>>> from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
>>> model = AutoModelForTokenClassification.from_pretrained("distilbert-base-uncased", num_labels=13)
```
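The number of labels can also be derived from the label list instead of being typed by hand. The sketch below repeats the list shown earlier as a literal so it runs on its own. The names `id2label` and `label2id` are chosen here for illustration, and passing such mappings to the model is optional.

```python
label_list = [
    "O",
    "B-corporation", "I-corporation",
    "B-creative-work", "I-creative-work",
    "B-group", "I-group",
    "B-location", "I-location",
    "B-person", "I-person",
    "B-product", "I-product",
]

# One output class per entry in the label list.
num_labels = len(label_list)

# Mappings between class ids and readable label names.
id2label = dict(enumerate(label_list))
label2id = {label: idx for idx, label in id2label.items()}

print(num_labels)
print(id2label[7], label2id["I-location"])
```

Deriving the count this way keeps the classification head in sync with the dataset if the label set ever changes.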
<Tip>
If you aren't familiar with fine-tuning a model with the [`Trainer`], take a look at the basic tutorial [here](../training#finetune-with-trainer)!
</Tip>
At this point, only three steps remain:
1. Define your training hyperparameters in [`TrainingArguments`].
2. Pass the training arguments to [`Trainer`] along with the model, dataset, tokenizer, and data collator.
3. Call [`~Trainer.train`] to fine-tune your model.
```py
>>> training_args = TrainingArguments(
...     output_dir="./results",
...     evaluation_strategy="epoch",
...     learning_rate=2e-5,
...     per_device_train_batch_size=16,
...     per_device_eval_batch_size=16,
...     num_train_epochs=3,
...     weight_decay=0.01,
... )
>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     train_dataset=tokenized_wnut["train"],
...     eval_dataset=tokenized_wnut["test"],
...     tokenizer=tokenizer,
...     data_collator=data_collator,
... )
>>> trainer.train()
```
</pt>
<tf>
To fine-tune a model in TensorFlow, start by converting your datasets to the `tf.data.Dataset` format with [`to_tf_dataset`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.to_tf_dataset). Specify inputs and labels in `columns`, whether to shuffle the dataset order, batch size, and the data collator:
```py
>>> tf_train_set = tokenized_wnut["train"].to_tf_dataset(
...     columns=["attention_mask", "input_ids", "labels"],
...     shuffle=True,
...     batch_size=16,
...     collate_fn=data_collator,
... )
>>> tf_validation_set = tokenized_wnut["validation"].to_tf_dataset(
...     columns=["attention_mask", "input_ids", "labels"],
...     shuffle=False,
...     batch_size=16,
...     collate_fn=data_collator,
... )
```
<Tip>
If you aren't familiar with fine-tuning a model with Keras, take a look at the basic tutorial [here](../training#finetune-with-keras)!
</Tip>
Set up an optimizer function, learning rate schedule, and some training hyperparameters:
```py
>>> from transformers import create_optimizer
>>> batch_size = 16
>>> num_train_epochs = 3
>>> num_train_steps = (len(tokenized_wnut["train"]) // batch_size) * num_train_epochs
>>> optimizer, lr_schedule = create_optimizer(
...     init_lr=2e-5,
...     num_train_steps=num_train_steps,
...     weight_decay_rate=0.01,
...     num_warmup_steps=0,
... )
```
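The step count above uses floor division, so a final partial batch is not counted. As a worked example with a hypothetical dataset of 1,000 examples (not the real size of WNUT 17), the sketch below compares the floor and ceiling variants. `total_train_steps` is an illustrative helper, not a library function.

```python
import math


def total_train_steps(num_examples, batch_size, num_epochs, count_partial_batch=False):
    """Number of optimizer steps across all epochs."""
    if count_partial_batch:
        steps_per_epoch = math.ceil(num_examples / batch_size)
    else:
        steps_per_epoch = num_examples // batch_size
    return steps_per_epoch * num_epochs


# Hypothetical size: 1000 examples, batches of 16, 3 epochs.
floor_steps = total_train_steps(1000, 16, 3)  # 62 full batches per epoch
ceil_steps = total_train_steps(1000, 16, 3, count_partial_batch=True)  # 63 batches per epoch

print(floor_steps, ceil_steps)
```

Which variant matches a given run depends on whether the data pipeline drops or keeps the last partial batch. The difference is only a few steps here, but it determines where the learning rate schedule reaches its end.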
Load DistilBERT with [`TFAutoModelForTokenClassification`] along with the number of expected labels:
```py
>>> from transformers import TFAutoModelForTokenClassification
>>> model = TFAutoModelForTokenClassification.from_pretrained("distilbert-base-uncased", num_labels=13)
```
Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
```py
>>> import tensorflow as tf
>>> model.compile(optimizer=optimizer)
```
Call [`fit`](https://keras.io/api/models/model_training_apis/#fit-method) to fine-tune the model:
```py
>>> model.fit(x=tf_train_set, validation_data=tf_validation_set, epochs=num_train_epochs)
```
</tf>
</frameworkcontent>
<Tip>
For a more in-depth example of how to fine-tune a model for token classification, take a look at the corresponding
[PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/token_classification.ipynb)
or [TensorFlow notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/token_classification-tf.ipynb).
</Tip>