Bart can make decoder_input_ids from labels (#6758)

This commit is contained in:
Sam Shleifer 2020-08-31 16:16:47 -04:00 committed by GitHub
parent b9772897ec
commit 367235ee52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -58,8 +58,8 @@ BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/bart-large-cnn",
"facebook/bart-large-xsum",
"facebook/mbart-large-en-ro",
# See all BART models at https://huggingface.co/models?filter=bart
]
# This list is incomplete. See all BART models at https://huggingface.co/models?filter=bart
BART_START_DOCSTRING = r"""
@ -1045,6 +1045,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
if labels is not None:
use_cache = False
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
outputs = self.model(
input_ids,