
# Pegasus

## Overview
The Pegasus model was proposed in PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu on Dec 18, 2019.
According to the abstract,
- Pegasus' pretraining task is intentionally similar to summarization: important sentences are removed/masked from an input document and are generated together as one output sequence from the remaining sentences, similar to an extractive summary.
- Pegasus achieves SOTA summarization performance on all 12 downstream tasks, as measured by ROUGE and human eval.
This model was contributed by sshleifer. The authors' code can be found here.
## Usage tips

- Sequence-to-sequence model with the same encoder-decoder model architecture as BART. Pegasus is pre-trained jointly on two self-supervised objective functions: Masked Language Modeling (MLM) and a novel summarization-specific pretraining objective, called Gap Sentence Generation (GSG).

  - MLM: encoder input tokens are randomly replaced by a mask token and have to be predicted by the encoder (like in BERT)
  - GSG: whole encoder input sentences are replaced by a second mask token and fed to the decoder, which has a causal mask to hide future words like a regular auto-regressive transformer decoder.

- FP16 is not supported (help/ideas on this appreciated!).
- The adafactor optimizer is recommended for Pegasus fine-tuning.
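
As a minimal sketch of that recommendation: the optimizer settings below (explicit learning rate, `scale_parameter=False`, `relative_step=False`) are a common choice for Pegasus/T5-style fine-tuning rather than an official recipe, and the toy document/summary pair is only a placeholder.

```python
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from transformers.optimization import Adafactor

model_name = "google/pegasus-xsum"  # any Pegasus checkpoint works here
tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name)

# Adafactor with a fixed learning rate (relative_step=False requires an explicit lr).
optimizer = Adafactor(
    model.parameters(),
    lr=1e-4,
    scale_parameter=False,
    relative_step=False,
    warmup_init=False,
)

# One illustrative training step on a toy document/summary pair.
inputs = tokenizer(["A long source document ..."], return_tensors="pt", truncation=True)
labels = tokenizer(["A short summary."], return_tensors="pt", truncation=True).input_ids

loss = model(**inputs, labels=labels).loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
```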
## Checkpoints

All the checkpoints are fine-tuned for summarization, besides *pegasus-large*, from which the other checkpoints are fine-tuned:

- Each checkpoint is 2.2 GB on disk and 568M parameters.
- FP16 is not supported (help/ideas on this appreciated!).
- Summarizing xsum in fp32 takes about 400ms/sample, with default parameters on a v100 GPU.
- Full replication results and correctly pre-processed data can be found in this Issue.
- Distilled checkpoints are described in this paper.
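
As a quick sanity check of those numbers, you can count the parameters of a loaded checkpoint directly (the checkpoint name below is just an example):

```python
from transformers import PegasusForConditionalGeneration

model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
num_params = sum(p.numel() for p in model.parameters())
print(f"{num_params / 1e6:.0f}M parameters")  # roughly 570M for the public checkpoints
```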
## Implementation Notes

- All models are transformer encoder-decoders with 16 layers in each component.
- The implementation is completely inherited from [`BartForConditionalGeneration`].
- Some key configuration differences:
  - static, sinusoidal position embeddings
  - the model starts generating with `pad_token_id` (which has 0 token_embedding) as the prefix.
  - more beams are used (`num_beams=8`)
- All pretrained pegasus checkpoints are the same besides three attributes: `tokenizer.model_max_length` (maximum input size), `max_length` (the maximum number of tokens to generate) and `length_penalty`.
- The code to convert checkpoints trained in the author's repo can be found in `convert_pegasus_tf_to_pytorch.py`.
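
A short sketch for inspecting those three attributes (plus the generation prefix) on a concrete checkpoint; the values noted in the comments are what the xsum checkpoint is expected to report and should be treated as illustrative:

```python
from transformers import PegasusConfig, PegasusTokenizer

config = PegasusConfig.from_pretrained("google/pegasus-xsum")
tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")

print(tokenizer.model_max_length)   # maximum input size
print(config.max_length)            # maximum number of tokens to generate
print(config.length_penalty)
print(config.num_beams)             # expected to be 8 for the fine-tuned checkpoints
print(config.decoder_start_token_id == config.pad_token_id)  # generation starts from the pad token
```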
## Usage Example

```python
>>> from transformers import PegasusForConditionalGeneration, PegasusTokenizer
>>> import torch

>>> src_text = [
...     """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
... ]

>>> model_name = "google/pegasus-xsum"
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> tokenizer = PegasusTokenizer.from_pretrained(model_name)
>>> model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
>>> batch = tokenizer(src_text, truncation=True, padding="longest", return_tensors="pt").to(device)
>>> translated = model.generate(**batch)
>>> tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
>>> assert (
...     tgt_text[0]
...     == "California's largest electricity provider has turned off power to hundreds of thousands of customers."
... )
```
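
If you only need a summary and don't want to manage the tokenizer and model yourself, the same checkpoint can also be used through the `pipeline` API; this is a minimal sketch and the input text is just an excerpt of the example above:

```python
from transformers import pipeline

summarizer = pipeline("summarization", model="google/pegasus-xsum")
result = summarizer(
    "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
    "amid dry conditions. The aim is to reduce the risk of wildfires."
)
print(result[0]["summary_text"])
```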
## Resources

- Script to fine-tune pegasus on the XSUM dataset. Data download instructions at examples/pytorch/summarization/.
- Causal language modeling task guide
- Translation task guide
- Summarization task guide
## PegasusConfig

[[autodoc]] PegasusConfig

## PegasusTokenizer

warning: `add_tokens` does not work at the moment.

[[autodoc]] PegasusTokenizer

## PegasusTokenizerFast

[[autodoc]] PegasusTokenizerFast

## PegasusModel

[[autodoc]] PegasusModel
    - forward

## PegasusForConditionalGeneration

[[autodoc]] PegasusForConditionalGeneration
    - forward

## PegasusForCausalLM

[[autodoc]] PegasusForCausalLM
    - forward

## TFPegasusModel

[[autodoc]] TFPegasusModel
    - call

## TFPegasusForConditionalGeneration

[[autodoc]] TFPegasusForConditionalGeneration
    - call

## FlaxPegasusModel

[[autodoc]] FlaxPegasusModel
    - __call__
    - encode
    - decode

## FlaxPegasusForConditionalGeneration

[[autodoc]] FlaxPegasusForConditionalGeneration
    - __call__
    - encode
    - decode