update docstrings; rename lm_labels to more explicit ltr_lm_labels

Rémi Louf 2019-10-29 20:08:03 +01:00
parent dfce409691
commit 098a89f312
2 changed files with 32 additions and 27 deletions


@@ -26,7 +26,7 @@ import numpy as np
from tqdm import tqdm, trange
import torch
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import (
AutoTokenizer,
@@ -283,14 +283,14 @@ def evaluate(args, model, tokenizer, prefix=""):
model.eval()
for batch in tqdm(eval_dataloader, desc="Evaluating"):
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, ltr_lm_labels = batch
source = source.to(args.device)
target = target.to(args.device)
encoder_token_type_ids = encoder_token_type_ids.to(args.device)
encoder_mask = encoder_mask.to(args.device)
decoder_mask = decoder_mask.to(args.device)
lm_labels = lm_labels.to(args.device)
ltr_lm_labels = ltr_lm_labels.to(args.device)
with torch.no_grad():
outputs = model(
@@ -299,7 +299,7 @@ def evaluate(args, model, tokenizer, prefix=""):
encoder_token_type_ids=encoder_token_type_ids,
encoder_attention_mask=encoder_mask,
decoder_attention_mask=decoder_mask,
decoder_lm_labels=lm_labels,
decoder_ltr_lm_labels=ltr_lm_labels,
)
lm_loss = outputs[0]
eval_loss += lm_loss.mean().item()
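The per-tensor .to(args.device) calls above can be collapsed into a single unpacking pass. A minimal, self-contained sketch of the same pattern (the stand-in batch below is made up for illustration; in the script the tuple comes from eval_dataloader):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch = tuple(torch.zeros(2, 8, dtype=torch.long) for _ in range(6))  # stand-in for a dataloader batch

    # Move every tensor of the batch to the device in one pass.
    source, target, encoder_token_type_ids, encoder_mask, decoder_mask, ltr_lm_labels = (
        t.to(device) for t in batch
    )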


@@ -548,6 +548,14 @@ BERT_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**encoder_hidden_states**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``:
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model
is configured as a decoder.
**encoder_attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on the padding token indices of the encoder input. This mask
is used in the cross-attention if the model is configured as a decoder.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
"""
@add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
@@ -609,26 +617,18 @@ class BertModel(BertPreTrainedModel):
head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
""" Forward pass on the Model.
The values of the attention matrix (shape [batch_size, seq_length])
should be 1.0 for the position we want to attend to and 0. for the ones
we do not want to attend to.
The model can behave as an encoder (with only self-attention) as well
as a decoder, in which case a layer of cross-attention is added between
ever self-attention layer, following the architecture described in [1].
the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
To behave like as a decoder the model needs to be initialized with the
`is_decoder` argument of the config set to `True`. An
To behave as a decoder the model needs to be initialized with the
`is_decoder` argument of the configuration set to `True`; an
`encoder_hidden_states` is expected as an input to the forward pass.
When a decoder, there are two kinds of attention masks to specify:
(1) Self-attention masks that need to be causal (only attends to
previous tokens);
(2) A cross-attention mask that prevents the module
from attending to the encoder's padding tokens.
.. _`Attention is all you need`:
https://arxiv.org/abs/1706.03762
[1] Vaswani, Ashish, et al. "Attention is all you need." Advances in
neural information processing systems. 2017.
"""
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
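A minimal sketch of the decoder usage this docstring describes, assuming the branch's BertModel reads is_decoder from its configuration and accepts the encoder_hidden_states / encoder_attention_mask keyword arguments shown in the signature above (shapes are made up):

    import torch
    from transformers import BertConfig, BertModel

    config = BertConfig()
    config.is_decoder = True  # adds cross-attention between the self-attention layers
    decoder = BertModel(config)

    input_ids = torch.randint(0, config.vocab_size, (1, 10))        # decoder input tokens
    encoder_hidden_states = torch.randn(1, 12, config.hidden_size)  # output of a separate encoder
    encoder_attention_mask = torch.ones(1, 12, dtype=torch.long)    # 1 = attend, 0 = encoder padding

    sequence_output = decoder(input_ids,
                              encoder_hidden_states=encoder_hidden_states,
                              encoder_attention_mask=encoder_attention_mask)[0]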
@@ -791,11 +791,16 @@ class BertForMaskedLM(BertPreTrainedModel):
Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
**ltr_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the left-to-right language modeling loss (next word prediction).
Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**masked_lm_loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
**next_token_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
**ltr_lm_loss**: (`optional`, returned when ``ltr_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Next token prediction loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
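A usage sketch for the renamed argument, assuming the ltr_lm_labels keyword introduced by this change (released versions of the library expose different label argument names). Passing the inputs themselves as labels makes the head predict the next token at every position:

    import torch
    from transformers import BertConfig, BertForMaskedLM

    config = BertConfig()
    model = BertForMaskedLM(config)  # randomly initialized; a pretrained checkpoint works the same way
    input_ids = torch.randint(0, config.vocab_size, (2, 16))

    outputs = model(input_ids, ltr_lm_labels=input_ids)
    ltr_lm_loss, prediction_scores = outputs[:2]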
@@ -833,7 +838,7 @@ class BertForMaskedLM(BertPreTrainedModel):
self.bert.embeddings.word_embeddings)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, ltr_lm_labels=None, ):
outputs = self.bert(input_ids,
attention_mask=attention_mask,
@@ -852,22 +857,22 @@ class BertForMaskedLM(BertPreTrainedModel):
# 1. If a tensor that contains the indices of masked labels is provided,
# the cross-entropy is the MLM cross-entropy that measures the likelihood
# of predictions for masked words.
# 2. If `lm_label` is provided we are in a causal scenario where we
# try to predict the next word for each input in the encoder.
# 2. If `ltr_lm_labels` is provided we are in a causal scenario where we
# try to predict the next token for each input in the decoder.
if masked_lm_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1) # -1 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
outputs = (masked_lm_loss,) + outputs
if lm_labels is not None:
if ltr_lm_labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
prediction_scores = prediction_scores[:, :-1, :].contiguous()
lm_labels = lm_labels[:, 1:].contiguous()
ltr_lm_labels = ltr_lm_labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss(ignore_index=-1)
next_token_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
outputs = (next_token_loss,) + outputs
ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), ltr_lm_labels.view(-1))
outputs = (ltr_lm_loss,) + outputs
return outputs # (masked_lm_loss), (next_token_loss), prediction_scores, (hidden_states), (attentions)
return outputs # (masked_lm_loss), (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
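The shift-by-one loss above can be reproduced in isolation; a sketch with made-up shapes, using only plain PyTorch:

    import torch
    from torch.nn import CrossEntropyLoss

    batch_size, seq_length, vocab_size = 2, 8, 30522  # illustrative shapes
    prediction_scores = torch.randn(batch_size, seq_length, vocab_size)
    ltr_lm_labels = torch.randint(0, vocab_size, (batch_size, seq_length))

    # Scores at position i are compared with the token at position i + 1,
    # so drop the last position's scores and the first position's labels.
    shifted_scores = prediction_scores[:, :-1, :].contiguous()
    shifted_labels = ltr_lm_labels[:, 1:].contiguous()

    loss_fct = CrossEntropyLoss(ignore_index=-1)  # labels set to -1 are ignored
    ltr_lm_loss = loss_fct(shifted_scores.view(-1, vocab_size), shifted_labels.view(-1))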
@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,