* Added support for Albert when fine-tuning for NER

* Added support for Albert in NER pipeline

* Added command-line options to examples/ner/run_ner.py to better control tokenization

* Added class AlbertForTokenClassification

* Changed output for NerPipeline to use .convert_ids_to_tokens(...) instead of .decode(...) to better reflect tokens
Martin Malmsten 2020-02-23 21:05:22 +01:00
parent 129f0604ac
commit 869b66f6b3
5 changed files with 129 additions and 5 deletions
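Once a checkpoint has been fine-tuned with the updated examples/ner/run_ner.py, the new model class plugs straight into NerPipeline. A minimal usage sketch (the checkpoint path and example sentence are placeholders, not part of this commit):

from transformers import AlbertForTokenClassification, AlbertTokenizer, pipeline

# "path/to/albert-ner" is a hypothetical directory produced by run_ner.py --do_train
model = AlbertForTokenClassification.from_pretrained("path/to/albert-ner")
tokenizer = AlbertTokenizer.from_pretrained("path/to/albert-ner")

ner = pipeline("ner", model=model, tokenizer=tokenizer)
# Each entry carries "word", "score" and "entity"; "word" is now the raw token piece
# (see the NerPipeline change below) rather than decoded text.
print(ner("Martin lives in Stockholm"))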

examples/ner/run_ner.py

@@ -33,6 +33,9 @@ from tqdm import tqdm, trange
from transformers import (
WEIGHTS_NAME,
AdamW,
AlbertConfig,
AlbertForTokenClassification,
AlbertTokenizer,
BertConfig,
BertForTokenClassification,
BertTokenizer,
@@ -70,6 +73,7 @@ ALL_MODELS = sum(
)
MODEL_CLASSES = {
"albert": (AlbertConfig, AlbertForTokenClassification, AlbertTokenizer),
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
"roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
"distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
@@ -77,6 +81,8 @@ MODEL_CLASSES = {
"xlmroberta": (XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer),
}
TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]
def set_seed(args):
random.seed(args.seed)
@@ -463,6 +469,22 @@ def main():
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
)
parser.add_argument(
"--keep_accents", action="store_const", const=True, help="Set this flag if model is trained with accents."
)
parser.add_argument(
"--strip_accents", action="store_const", const=True, help="Set this flag if model is trained without accents."
)
parser.add_argument(
"--nouse_fast",
action="store_const",
dest="use_fast",
const=False,
help="Set this flag to not use fast tokenization.",
)
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
parser.add_argument(
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
@@ -590,10 +612,12 @@ def main():
label2id={label: i for i, label in enumerate(labels)},
cache_dir=args.cache_dir if args.cache_dir else None,
)
tokenizer_args = {k: v for k, v in vars(args).items() if v is not None and k in TOKENIZER_ARGS}
logger.info("Tokenizer arguments: %s", tokenizer_args)
tokenizer = tokenizer_class.from_pretrained(
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
- do_lower_case=args.do_lower_case,
cache_dir=args.cache_dir if args.cache_dir else None,
+ **tokenizer_args
)
model = model_class.from_pretrained(
args.model_name_or_path,
@@ -636,7 +660,7 @@ def main():
# Evaluation
results = {}
if args.do_eval and args.local_rank in [-1, 0]:
- tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
+ tokenizer = tokenizer_class.from_pretrained(args.output_dir, **tokenizer_args)
checkpoints = [args.output_dir]
if args.eval_all_checkpoints:
checkpoints = list(
@@ -658,7 +682,7 @@ def main():
writer.write("{} = {}\n".format(key, str(results[key])))
if args.do_predict and args.local_rank in [-1, 0]:
- tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
+ tokenizer = tokenizer_class.from_pretrained(args.output_dir, **tokenizer_args)
model = model_class.from_pretrained(args.output_dir)
model.to(args.device)
result, predictions = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="test")
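
The three new flags use action="store_const" so that an option left unset stays None and is dropped by the tokenizer_args filter above, while do_lower_case keeps its store_true behaviour. A standalone sketch of that pattern (argument names are taken from the diff, everything else is illustrative):

import argparse

TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]

parser = argparse.ArgumentParser()
parser.add_argument("--do_lower_case", action="store_true")
parser.add_argument("--keep_accents", action="store_const", const=True)
parser.add_argument("--strip_accents", action="store_const", const=True)
parser.add_argument("--nouse_fast", action="store_const", dest="use_fast", const=False)

args = parser.parse_args(["--keep_accents"])
# Unset store_const options are None and get filtered out; store_true defaults to False
# and is therefore always forwarded to the tokenizer.
tokenizer_args = {k: v for k, v in vars(args).items() if v is not None and k in TOKENIZER_ARGS}
print(tokenizer_args)  # {'do_lower_case': False, 'keep_accents': True}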

src/transformers/__init__.py

@@ -255,6 +255,7 @@ if is_torch_available():
AlbertForMaskedLM,
AlbertForSequenceClassification,
AlbertForQuestionAnswering,
AlbertForTokenClassification,
load_tf_weights_in_albert,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)

src/transformers/modeling_albert.py

@@ -600,7 +600,7 @@ class AlbertMLMHead(nn.Module):
hidden_states = self.LayerNorm(hidden_states)
hidden_states = self.decoder(hidden_states)
- prediction_scores = hidden_states
+ prediction_scores = hidden_states + self.bias
return prediction_scores
@@ -788,6 +788,103 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings(
"""Albert Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
ALBERT_START_DOCSTRING,
)
class AlbertForTokenClassification(AlbertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.albert = AlbertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
self.init_weights()
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
from transformers import AlbertTokenizer, AlbertForTokenClassification
import torch
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
model = AlbertForTokenClassification.from_pretrained('albert-base-v2')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2]
"""
outputs = self.albert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)[active_loss]
active_labels = labels.view(-1)[active_loss]
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings(
"""Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
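
The loss computation in AlbertForTokenClassification above masks out padding positions via the attention mask, mirroring the other token-classification heads. A self-contained sketch of that masking step (shapes and values are made up for illustration):

import torch
from torch.nn import CrossEntropyLoss

num_labels = 3
logits = torch.randn(2, 4, num_labels)          # (batch_size, seq_len, num_labels)
labels = torch.randint(0, num_labels, (2, 4))   # one label id per token
attention_mask = torch.tensor([[1, 1, 1, 0],    # 0 marks padding positions
                               [1, 1, 0, 0]])

loss_fct = CrossEntropyLoss()
active_loss = attention_mask.view(-1) == 1                 # flatten, keep real tokens
active_logits = logits.view(-1, num_labels)[active_loss]
active_labels = labels.view(-1)[active_loss]
loss = loss_fct(active_logits, active_labels)              # padding does not contribute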

src/transformers/modeling_auto.py

@@ -42,6 +42,7 @@ from .modeling_albert import (
AlbertForMaskedLM,
AlbertForQuestionAnswering,
AlbertForSequenceClassification,
AlbertForTokenClassification,
AlbertModel,
)
from .modeling_bart import BART_PRETRAINED_MODEL_ARCHIVE_MAP, BartForMaskedLM, BartForSequenceClassification, BartModel
@@ -233,6 +234,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
(RobertaConfig, RobertaForTokenClassification),
(BertConfig, BertForTokenClassification),
(XLNetConfig, XLNetForTokenClassification),
(AlbertConfig, AlbertForTokenClassification),
]
)
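
With this mapping entry, the auto classes can resolve an ALBERT checkpoint to the new head. A brief sketch (num_labels=9 is just the usual CoNLL-2003 label count, not something set by this commit; the public base checkpoint still needs fine-tuning before it predicts useful entities):

from transformers import AutoConfig, AutoModelForTokenClassification

config = AutoConfig.from_pretrained("albert-base-v2", num_labels=9)
model = AutoModelForTokenClassification.from_pretrained("albert-base-v2", config=config)
print(type(model).__name__)  # AlbertForTokenClassification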

src/transformers/pipelines.py

@@ -636,7 +636,7 @@ class NerPipeline(Pipeline):
if self.model.config.id2label[label_idx] not in self.ignore_labels:
answer += [
{
"word": self.tokenizer.decode([int(input_ids[idx])]),
"word": self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])),
"score": score[idx][label_idx].item(),
"entity": self.model.config.id2label[label_idx],
}
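
The switch from decode(...) to convert_ids_to_tokens(...) keeps the raw subword piece (including the SentencePiece/WordPiece marker) instead of cleaned-up text. A small illustration (the checkpoint and the exact pieces shown are assumptions, hence the "e.g." comments):

from transformers import AlbertTokenizer

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
ids = tokenizer.encode("Hello Stockholm", add_special_tokens=False)

print(tokenizer.decode([ids[0]]))               # e.g. "hello" - detokenized text
print(tokenizer.convert_ids_to_tokens(ids[0]))  # e.g. "▁hello" - raw piece, word boundary visible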