Bart: update example for #3140 compatibility (#3233)

* Update bart example docs
This commit is contained in:
Sam Shleifer 2020-03-12 10:36:37 -04:00 committed by GitHub
parent 72768b6b9c
commit 2e81b9d8d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 4 deletions

View File

@ -1,4 +1,4 @@
### Get the CNN/Daily Mail Data
### Get the CNN Data
To be able to reproduce the authors' results on the CNN/Daily Mail dataset you first need to download both CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") in the same folder. Then uncompress the archives by running:
```bash
@ -32,6 +32,7 @@ unzip stanford-corenlp-full-2018-10-05.zip
cd stanford-corenlp-full-2018-10-05
export CLASSPATH=stanford-corenlp-3.9.2.jar:stanford-corenlp-3.9.2-models.jar
```
Then run `ptb_tokenize` on `test.target` and your generated hypotheses.
### Rouge Setup
Install `files2rouge` following the instructions at [here](https://github.com/pltrdy/files2rouge).
I also needed to run `sudo apt-get install libxml-parser-perl`

View File

@ -27,9 +27,11 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
attention_mask=dct["attention_mask"].to(device),
num_beams=4,
length_penalty=2.0,
max_length=140,
min_length=55,
max_length=142, # +2 from original because we start at step=1 and stop before max_length
min_length=56, # +1 from original because we start at step=1
no_repeat_ngram_size=3,
early_stopping=True,
do_sample=False,
)
dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
for hypothesis in dec:

View File

@ -45,6 +45,20 @@ BART_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
BART_GENERATION_EXAMPLE = r"""
Examples::
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
# see ``examples/summarization/bart/evaluate_cnn.py`` for a longer example
model = BartForConditionalGeneration.from_pretrained('bart-large-cnn')
tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
# Generate Summary
summary_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_beams=4, max_length=5)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
"""
BART_INPUTS_DOCSTRING = r"""
@ -855,7 +869,8 @@ class BartModel(PretrainedBartModel):
@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING,
"The BART Model with a language modeling head. Can be used for summarization.",
BART_START_DOCSTRING + BART_GENERATION_EXAMPLE,
)
class BartForConditionalGeneration(PretrainedBartModel):
base_model_prefix = "model"