Fix and improve documentation for LEDForConditionalGeneration (#12303)

* Replace conditional generation example (fixes #12268)

* Replace model in summarization example with finetuned checkpoint, adapt example text

* Fix typo in new summarization example

* Fix docstring formatting, add missing import statement to example
This commit is contained in:
Kilian Kluge 2021-06-22 15:58:13 +02:00 committed by GitHub
parent 1498eb9888
commit 032d56a435
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1436,17 +1436,43 @@ LED_START_DOCSTRING = r"""
LED_GENERATION_EXAMPLE = r"""
Summarization example::
>>> from transformers import LEDTokenizer, LEDForConditionalGeneration, LEDConfig
>>> import torch
>>> from transformers import LEDTokenizer, LEDForConditionalGeneration
>>> model = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384')
>>> tokenizer = LEDTokenizer.from_pretrained('allenai/led-base-16384')
>>> model = LEDForConditionalGeneration.from_pretrained('allenai/led-large-16384-arxiv')
>>> tokenizer = LEDTokenizer.from_pretrained('allenai/led-large-16384-arxiv')
>>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
>>> ARTICLE_TO_SUMMARIZE = '''Transformers (Vaswani et al., 2017) have achieved state-of-the-art
... results in a wide range of natural language tasks including generative
... language modeling (Dai et al., 2019; Radford et al., 2019) and discriminative
... language understanding (Devlin et al., 2019). This success is partly due to
... the self-attention component which enables the network to capture contextual
... information from the entire sequence. While powerful, the memory and computational
... requirements of self-attention grow quadratically with sequence length, making
... it infeasible (or very expensive) to process long sequences.
...
... To address this limitation, we present Longformer, a modified Transformer
... architecture with a self-attention operation that scales linearly with the
... sequence length, making it versatile for processing long documents (Fig 1). This
... is an advantage for natural language tasks such as long document classification,
... question answering (QA), and coreference resolution, where existing approaches
... partition or shorten the long context into smaller sequences that fall within the
... typical 512 token limit of BERT-style pretrained models. Such partitioning could
... potentially result in loss of important cross-partition information, and to
... mitigate this problem, existing methods often rely on complex architectures to
... address such interactions. On the other hand, our proposed Longformer is able to
... build contextual representations of the entire context using multiple layers of
... attention, reducing the need for task-specific architectures.'''
>>> inputs = tokenizer.encode(ARTICLE_TO_SUMMARIZE, return_tensors='pt')
>>> # Global attention on the first token (cf. Beltagy et al. 2020)
>>> global_attention_mask = torch.zeros_like(inputs)
>>> global_attention_mask[:, 0] = 1
>>> # Generate Summary
>>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
>>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
>>> summary_ids = model.generate(inputs, global_attention_mask=global_attention_mask,
... num_beams=3, max_length=32, early_stopping=True)
>>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True))
"""
LED_INPUTS_DOCSTRING = r"""
@ -2305,13 +2331,9 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
>>> model = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384')
>>> input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']
>>> logits = model(input_ids).logits
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
>>> probs = logits[0, masked_index].softmax(dim=0)
>>> values, predictions = probs.topk(5)
>>> tokenizer.decode(predictions).split()
>>> prediction = model.generate(input_ids)[0]
>>> print(tokenizer.decode(prediction, skip_special_tokens=True))
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict