mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
* Added support for Albert when fine-tuning for NER
* Added support for Albert in NER pipeline * Added command-line options to examples/ner/run_ner.py to better control tokenization * Added class AlbertForTokenClassification * Changed output for NerPipeline to use .convert_ids_to_tokens(...) instead of .decode(...) to better reflect tokens
This commit is contained in:
parent
129f0604ac
commit
869b66f6b3
@ -33,6 +33,9 @@ from tqdm import tqdm, trange
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
AlbertConfig,
|
||||
AlbertForTokenClassification,
|
||||
AlbertTokenizer,
|
||||
BertConfig,
|
||||
BertForTokenClassification,
|
||||
BertTokenizer,
|
||||
@ -70,6 +73,7 @@ ALL_MODELS = sum(
|
||||
)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"albert": (AlbertConfig, AlbertForTokenClassification, AlbertTokenizer),
|
||||
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
|
||||
@ -77,6 +81,8 @@ MODEL_CLASSES = {
|
||||
"xlmroberta": (XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer),
|
||||
}
|
||||
|
||||
TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
@ -463,6 +469,22 @@ def main():
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--keep_accents", action="store_const", const=True, help="Set this flag if model is trained with accents."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--strip_accents", action="store_const", const=True, help="Set this flag if model is trained without accents."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nouse_fast",
|
||||
action="store_const",
|
||||
dest="use_fast",
|
||||
const=False,
|
||||
help="Set this flag to not use fast tokenization.",
|
||||
)
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
@ -590,10 +612,12 @@ def main():
|
||||
label2id={label: i for i, label in enumerate(labels)},
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer_args = {k: v for k, v in vars(args).items() if v != None and k in TOKENIZER_ARGS}
|
||||
logger.info("Tokenizer arguments: %s", tokenizer_args)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
**tokenizer_args
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
@ -636,7 +660,7 @@ def main():
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, **tokenizer_args)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(
|
||||
@ -658,7 +682,7 @@ def main():
|
||||
writer.write("{} = {}\n".format(key, str(results[key])))
|
||||
|
||||
if args.do_predict and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, **tokenizer_args)
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
result, predictions = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="test")
|
||||
|
@ -255,6 +255,7 @@ if is_torch_available():
|
||||
AlbertForMaskedLM,
|
||||
AlbertForSequenceClassification,
|
||||
AlbertForQuestionAnswering,
|
||||
AlbertForTokenClassification,
|
||||
load_tf_weights_in_albert,
|
||||
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
@ -600,7 +600,7 @@ class AlbertMLMHead(nn.Module):
|
||||
hidden_states = self.LayerNorm(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
|
||||
prediction_scores = hidden_states
|
||||
prediction_scores = hidden_states + self.bias
|
||||
|
||||
return prediction_scores
|
||||
|
||||
@ -788,6 +788,103 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
||||
return outputs # (loss), logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""Albert Model with a token classification head on top (a linear layer on top of
|
||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||
ALBERT_START_DOCSTRING,
|
||||
)
|
||||
class AlbertForTokenClassification(AlbertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.albert = AlbertModel(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for computing the token classification loss.
|
||||
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
||||
|
||||
Returns:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
|
||||
Classification loss.
|
||||
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
|
||||
Classification scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
Examples::
|
||||
|
||||
from transformers import AlbertTokenizer, AlbertForTokenClassification
|
||||
import torch
|
||||
|
||||
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
|
||||
model = AlbertForTokenClassification.from_pretrained('albert-base-v2')
|
||||
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, labels=labels)
|
||||
|
||||
loss, scores = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
outputs = self.albert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
# Only keep active parts of the loss
|
||||
if attention_mask is not None:
|
||||
active_loss = attention_mask.view(-1) == 1
|
||||
active_logits = logits.view(-1, self.num_labels)[active_loss]
|
||||
active_labels = labels.view(-1)[active_loss]
|
||||
loss = loss_fct(active_logits, active_labels)
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs # (loss), logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
|
@ -42,6 +42,7 @@ from .modeling_albert import (
|
||||
AlbertForMaskedLM,
|
||||
AlbertForQuestionAnswering,
|
||||
AlbertForSequenceClassification,
|
||||
AlbertForTokenClassification,
|
||||
AlbertModel,
|
||||
)
|
||||
from .modeling_bart import BART_PRETRAINED_MODEL_ARCHIVE_MAP, BartForMaskedLM, BartForSequenceClassification, BartModel
|
||||
@ -233,6 +234,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
(RobertaConfig, RobertaForTokenClassification),
|
||||
(BertConfig, BertForTokenClassification),
|
||||
(XLNetConfig, XLNetForTokenClassification),
|
||||
(AlbertConfig, AlbertForTokenClassification),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -636,7 +636,7 @@ class NerPipeline(Pipeline):
|
||||
if self.model.config.id2label[label_idx] not in self.ignore_labels:
|
||||
answer += [
|
||||
{
|
||||
"word": self.tokenizer.decode([int(input_ids[idx])]),
|
||||
"word": self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])),
|
||||
"score": score[idx][label_idx].item(),
|
||||
"entity": self.model.config.id2label[label_idx],
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user