mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)

commit 098a89f312 (parent dfce409691)

    update docstrings; rename lm_labels to more explicit ltr_lm_labels
@@ -26,7 +26,7 @@ import numpy as np
 from tqdm import tqdm, trange
 import torch
 from torch.optim import Adam
-from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 
 from transformers import (
     AutoTokenizer,
@@ -283,14 +283,14 @@ def evaluate(args, model, tokenizer, prefix=""):
     model.eval()
 
     for batch in tqdm(eval_dataloader, desc="Evaluating"):
-        source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
+        source, target, encoder_token_type_ids, encoder_mask, decoder_mask, ltr_lm_labels = batch
 
         source = source.to(args.device)
         target = target.to(args.device)
         encoder_token_type_ids = encoder_token_type_ids.to(args.device)
         encoder_mask = encoder_mask.to(args.device)
         decoder_mask = decoder_mask.to(args.device)
-        lm_labels = lm_labels.to(args.device)
+        ltr_lm_labels = ltr_lm_labels.to(args.device)
 
         with torch.no_grad():
             outputs = model(
@@ -299,7 +299,7 @@ def evaluate(args, model, tokenizer, prefix=""):
                 encoder_token_type_ids=encoder_token_type_ids,
                 encoder_attention_mask=encoder_mask,
                 decoder_attention_mask=decoder_mask,
-                decoder_lm_labels=lm_labels,
+                decoder_ltr_lm_labels=ltr_lm_labels,
             )
             lm_loss = outputs[0]
             eval_loss += lm_loss.mean().item()
@@ -548,6 +548,14 @@ BERT_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **encoder_hidden_states**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``:
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model
+            is configured as a decoder.
+        **encoder_attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
+            is used in the cross-attention if the model is configured as a decoder.
+            Mask values selected in ``[0, 1]``:
+            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
 """
 
 @add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
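The two new mask entries above follow the same [0, 1] convention. As a quick illustration in plain PyTorch (not part of the commit; the token ids and the pad id of 0 are made up for the example), an ``encoder_attention_mask`` of this shape is typically derived from the padded encoder inputs:

import torch

# Hypothetical padded batch of encoder token ids; 0 is assumed to be the padding id.
encoder_input_ids = torch.tensor([
    [101, 7592, 2088, 102, 0, 0],    # 4 real tokens, 2 padding tokens
    [101, 2129, 2024, 2017, 102, 0], # 5 real tokens, 1 padding token
])

# ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padding) tokens.
encoder_attention_mask = (encoder_input_ids != 0).long()
# tensor([[1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 0]])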
@@ -609,26 +617,18 @@ class BertModel(BertPreTrainedModel):
                 head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
         """ Forward pass on the Model.
 
-        The values of the attention matrix (shape [batch_size, seq_length])
-        should be 1.0 for the position we want to attend to and 0. for the ones
-        we do not want to attend to.
-
         The model can behave as an encoder (with only self-attention) as well
         as a decoder, in which case a layer of cross-attention is added between
-        ever self-attention layer, following the architecture described in [1].
+        the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
+        Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
 
-        To behave like as a decoder the model needs to be initialized with the
-        `is_decoder` argument of the config set to `True`. An
+        To behave as a decoder the model needs to be initialized with the
+        `is_decoder` argument of the configuration set to `True`; an
         `encoder_hidden_states` is expected as an input to the forward pass.
-        When a decoder, there are two kinds of attention masks to specify:
 
-        (1) Self-attention masks that need to be causal (only attends to
-        previous tokens);
-        (2) A cross-attention mask that prevents the module
-        from attending to the encoder's padding tokens.
+        .. _`Attention is all you need`:
+            https://arxiv.org/abs/1706.03762
 
-        [1] Vaswani, Ashish, et al. "Attention is all you need." Advances in
-        neural information processing systems. 2017.
         """
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
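A minimal sketch of the decoder usage this docstring describes (not part of the commit). It follows the forward signature shown in the diff; the config values are arbitrary, and the ``add_cross_attention`` flag is an assumption needed by newer releases of the library rather than something introduced here.

import torch
from transformers import BertConfig, BertModel

# Small decoder-style BERT. At the time of this commit only `is_decoder=True`
# was required; `add_cross_attention=True` is assumed here for newer releases.
config = BertConfig(hidden_size=144, num_hidden_layers=2, num_attention_heads=4,
                    intermediate_size=256, is_decoder=True, add_cross_attention=True)
decoder = BertModel(config)

batch_size, src_len, tgt_len = 2, 6, 5
decoder_input_ids = torch.randint(0, config.vocab_size, (batch_size, tgt_len))

# Pretend these came from a separately run encoder over the source sequence.
encoder_hidden_states = torch.randn(batch_size, src_len, config.hidden_size)
encoder_attention_mask = torch.ones(batch_size, src_len, dtype=torch.long)

outputs = decoder(decoder_input_ids,
                  encoder_hidden_states=encoder_hidden_states,
                  encoder_attention_mask=encoder_attention_mask)
decoder_hidden_states = outputs[0]  # (batch_size, tgt_len, hidden_size)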
@@ -791,11 +791,16 @@ class BertForMaskedLM(BertPreTrainedModel):
             Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
             Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
             in ``[0, ..., config.vocab_size]``
+        **ltr_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+            Labels for computing the left-to-right language modeling loss (next word prediction).
+            Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
+            Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
+            in ``[0, ..., config.vocab_size]``
 
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
         **masked_lm_loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
             Masked language modeling loss.
-        **next_token_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+        **ltr_lm_loss**: (`optional`, returned when ``ltr_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
             Next token prediction loss.
         **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
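For readers of the new ``ltr_lm_labels`` entry, a small plain-PyTorch illustration (outside the commit; the pad id and token ids are made up) of the ``-1`` convention it shares with ``masked_lm_labels``: padding positions are set to ``-1`` so the loss ignores them.

import torch

# Build left-to-right LM labels from decoder input ids: real tokens keep their
# id as the target, padding positions (assumed pad id 0) become -1 and are
# ignored by CrossEntropyLoss(ignore_index=-1).
pad_token_id = 0
decoder_input_ids = torch.tensor([[101, 2023, 2003, 1037, 102, 0, 0]])
ltr_lm_labels = decoder_input_ids.masked_fill(decoder_input_ids == pad_token_id, -1)
# tensor([[ 101, 2023, 2003, 1037,  102,   -1,   -1]])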
@@ -833,7 +838,7 @@ class BertForMaskedLM(BertPreTrainedModel):
                                    self.bert.embeddings.word_embeddings)
 
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
-                masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
+                masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, ltr_lm_labels=None, ):
 
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
@@ -852,22 +857,22 @@ class BertForMaskedLM(BertPreTrainedModel):
         # 1. If a tensor that contains the indices of masked labels is provided,
         #    the cross-entropy is the MLM cross-entropy that measures the likelihood
         #    of predictions for masked words.
-        # 2. If `lm_label` is provided we are in a causal scenario where we
-        #    try to predict the next word for each input in the encoder.
+        # 2. If `ltr_lm_labels` is provided we are in a causal scenario where we
+        #    try to predict the next token for each input in the decoder.
         if masked_lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)  # -1 index = padding token
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             outputs = (masked_lm_loss,) + outputs
 
-        if lm_labels is not None:
+        if ltr_lm_labels is not None:
             # we are doing next-token prediction; shift prediction scores and input ids by one
             prediction_scores = prediction_scores[:, :-1, :].contiguous()
-            lm_labels = lm_labels[:, 1:].contiguous()
+            ltr_lm_labels = ltr_lm_labels[:, 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
-            next_token_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
-            outputs = (next_token_loss,) + outputs
+            ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), ltr_lm_labels.view(-1))
+            outputs = (ltr_lm_loss,) + outputs
 
-        return outputs  # (masked_lm_loss), (next_token_loss), prediction_scores, (hidden_states), (attentions)
+        return outputs  # (masked_lm_loss), (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
 
 
 @add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
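The shift-by-one in the branch above is the whole trick behind the left-to-right loss: the score produced at position i is compared with the token at position i + 1. A stand-alone restatement with dummy tensors (not part of the commit):

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
prediction_scores = torch.randn(2, 5, vocab_size)      # (batch_size, seq_len, vocab_size)
ltr_lm_labels = torch.randint(0, vocab_size, (2, 5))   # targets, same shape as the input ids

# Drop the last prediction and the first label so that position i is scored
# against token i + 1, exactly as in the diff above.
shifted_scores = prediction_scores[:, :-1, :].contiguous()
shifted_labels = ltr_lm_labels[:, 1:].contiguous()

loss_fct = CrossEntropyLoss(ignore_index=-1)
ltr_lm_loss = loss_fct(shifted_scores.view(-1, vocab_size), shifted_labels.view(-1))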