[Docs] Improve PyTorch, Flax generate API (#15988)

* Move generate docs

* up

* Update docs/source/_toctree.yml

* correct

* correct some stuff

* correct tests

* more fixes

* finish generate

* add to doc tests

* finish

* finalize

* add warning to generate method
Patrick von Platen 2022-03-10 11:54:45 +01:00 committed by GitHub
parent 0951d31788
commit 6ce11c2c0f
6 changed files with 202 additions and 98 deletions

docs/source/_toctree.yml

@@ -114,6 +114,8 @@
title: Logging
- local: main_classes/model
title: Models
- local: main_classes/text_generation
title: Text Generation
- local: main_classes/onnx
title: ONNX
- local: main_classes/optimizer_schedules

docs/source/main_classes/model.mdx

@@ -86,14 +86,6 @@ Due to PyTorch design, this functionality is only available for floating dtypes.
- push_to_hub
- all
## Generation
[[autodoc]] generation_utils.GenerationMixin
[[autodoc]] generation_tf_utils.TFGenerationMixin
[[autodoc]] generation_flax_utils.FlaxGenerationMixin
## Pushing to the Hub
[[autodoc]] file_utils.PushToHubMixin

docs/source/main_classes/text_generation.mdx

@@ -0,0 +1,39 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Generation
The methods for auto-regressive text generation, namely [`~generation_utils.GenerationMixin.generate`] (for the PyTorch models), [`~generation_tf_utils.TFGenerationMixin.generate`] (for the TensorFlow models) and [`~generation_flax_utils.FlaxGenerationMixin.generate`] (for the Flax/JAX models), are implemented in [`~generation_utils.GenerationMixin`], [`~generation_tf_utils.TFGenerationMixin`] and [`~generation_flax_utils.FlaxGenerationMixin`] respectively.
The `GenerationMixin` classes are inherited by the corresponding base model classes, [`PreTrainedModel`], [`TFPreTrainedModel`], and [`FlaxPreTrainedModel`] respectively, thereby exposing all
methods for auto-regressive text generation to every model class.
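Regardless of framework, the entry point looks the same. A minimal sketch (mirroring the greedy-decoding doctest added in this commit; `distilgpt2` is just an example checkpoint):

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# model.generate() is available because AutoModelForCausalLM subclasses PreTrainedModel,
# which in turn mixes in generation_utils.GenerationMixin
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

input_ids = tokenizer("Today I believe we can finally", return_tensors="pt").input_ids
outputs = model.generate(input_ids, do_sample=False, max_length=30)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```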
## GenerationMixin
[[autodoc]] generation_utils.GenerationMixin
- generate
- greedy_search
- sample
- beam_search
- beam_sample
- group_beam_search
- constrained_beam_search
## TFGenerationMixin
[[autodoc]] generation_tf_utils.TFGenerationMixin
- generate
## FlaxGenerationMixin
[[autodoc]] generation_flax_utils.FlaxGenerationMixin
- generate

src/transformers/generation_flax_utils.py

@@ -118,7 +118,16 @@ class BeamSearchState:
class FlaxGenerationMixin:
"""
A class containing all of the functions supporting generation, to be used as a mixin in [`FlaxPreTrainedModel`].
A class containing all functions for auto-regressive text generation, to be used as a mixin in
[`FlaxPreTrainedModel`].
The class exposes [`~generation_flax_utils.FlaxGenerationMixin.generate`], which can be used for:
- *greedy decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._greedy_search`] if
`num_beams=1` and `do_sample=False`.
- *multinomial sampling* by calling [`~generation_flax_utils.FlaxGenerationMixin._sample`] if `num_beams=1`
and `do_sample=True`.
- *beam-search decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._beam_search`] if `num_beams>1`
and `do_sample=False`.
"""
@staticmethod
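As an illustration of the three dispatch rules listed in the docstring above, a minimal sketch (assuming a `t5-small` checkpoint; any Flax model exposing the mixin works the same way):

```python
from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-small")
input_ids = tokenizer("translate English to German: How old are you?", return_tensors="np").input_ids

greedy = model.generate(input_ids, num_beams=1, do_sample=False)  # -> _greedy_search
sampled = model.generate(input_ids, num_beams=1, do_sample=True)  # -> _sample
beamed = model.generate(input_ids, num_beams=4, do_sample=False)  # -> _beam_search
```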
@@ -176,12 +185,23 @@ class FlaxGenerationMixin:
**model_kwargs,
):
r"""
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
and, multinomial sampling.
Generates sequences of token ids for models with a language modeling head. The method supports the following
generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
Apart from `input_ids`, all the arguments below will default to the value of the attribute of the same name
inside the [`PretrainedConfig`] of the model. The default values indicated are the default values of those
config.
- *greedy decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._greedy_search`] if
`num_beams=1` and `do_sample=False`.
- *multinomial sampling* by calling [`~generation_flax_utils.FlaxGenerationMixin._sample`] if `num_beams=1`
and `do_sample=True`.
- *beam-search decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._beam_search`] if `num_beams>1`
and `do_sample=False`.
<Tip warning={true}>
Apart from `input_ids`, all the arguments below will default to the value of the attribute of the same name as
defined in the model's config (`config.json`) which in turn defaults to the
[`~modeling_utils.PretrainedConfig`] of the model.
</Tip>
Most of these parameters are explained in more detail in [this blog
post](https://huggingface.co/blog/how-to-generate).
@@ -236,7 +256,7 @@ class FlaxGenerationMixin:
>>> input_ids = tokenizer(input_context, return_tensors="np").input_ids
>>> # generate candidates using sampling
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```"""
# set init values
max_length = max_length if max_length is not None else self.config.max_length
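One Flax-specific detail worth noting next to this example: sampling draws from an explicit PRNG key (the `prng_key` argument of `generate`, which defaults to `jax.random.PRNGKey(0)` when left unset), so reproducing a sampled output is a matter of passing the same key. A hedged sketch, reusing `model` and `input_ids` from the doctest above:

```python
import jax

# same key in, same sampled sequence out
outputs = model.generate(
    input_ids=input_ids, max_length=20, top_k=30, do_sample=True, prng_key=jax.random.PRNGKey(0)
)
```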

src/transformers/generation_utils.py

@@ -377,7 +377,21 @@ BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
class GenerationMixin:
"""
A class containing all of the functions supporting generation, to be used as a mixin in [`PreTrainedModel`].
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for:
- *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
`do_sample=False`.
- *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
`do_sample=True`.
- *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
`do_sample=False`.
- *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
`num_beams>1` and `do_sample=True`.
- *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
`num_beams>1` and `num_beam_groups>1`.
- *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`],
if `constraints!=None` or `force_words_ids!=None`.
"""
def _prepare_model_inputs(
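A minimal sketch of the dispatch rules listed in the class docstring above (`distilgpt2` is an assumed example checkpoint; only the flags differ between calls):

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
input_ids = tokenizer("The dog", return_tensors="pt").input_ids

model.generate(input_ids, num_beams=1, do_sample=False)  # greedy_search
model.generate(input_ids, num_beams=1, do_sample=True)  # sample
model.generate(input_ids, num_beams=4, do_sample=False)  # beam_search
model.generate(input_ids, num_beams=4, do_sample=True)  # beam_sample
model.generate(input_ids, num_beams=4, num_beam_groups=2)  # group_beam_search
```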
@@ -847,18 +861,37 @@ class GenerationMixin:
**model_kwargs,
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
r"""
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
multinomial sampling, beam-search decoding, and beam-search multinomial sampling.
Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name inside
the [`PretrainedConfig`] of the model. The default values indicated are the default values of those config.
Generates sequences of token ids for models with a language modeling head. The method supports the following
generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
- *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
`do_sample=False`.
- *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
`do_sample=True`.
- *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
`do_sample=False`.
- *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
`num_beams>1` and `do_sample=True`.
- *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
`num_beams>1` and `num_beam_groups>1`.
- *constrained beam-search decoding* by calling
[`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or
`force_words_ids!=None`.
<Tip warning={true}>
Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
defined in the model's config (`config.json`) which in turn defaults to the
[`~modeling_utils.PretrainedConfig`] of the model.
</Tip>
Most of these parameters are explained in more detail in [this blog
post](https://huggingface.co/blog/how-to-generate).
Parameters:
inputs (`torch.Tensor` of shape `(batch_size, sequence_length)`, `(batch_size, sequence_length,
feature_dim)` or `(batch_size, num_channels, height, width)`, *optional*):
inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
@@ -997,66 +1030,56 @@ class GenerationMixin:
Examples:
Greedy Decoding:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
>>> # do greedy decoding without providing a prompt
>>> outputs = model.generate(max_length=40)
>>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> document = (
... "at least two people were killed in a suspected bomb attack on a passenger bus "
... "in the strife-torn southern philippines on monday , the military said."
... )
>>> # encode input context
>>> input_ids = tokenizer(document, return_tensors="pt").input_ids
>>> # generate 3 independent sequences using beam search decoding (5 beams)
>>> # with T5 encoder-decoder model conditioned on short news article.
>>> outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
>>> input_context = "The dog"
>>> # encode input context
>>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
>>> # generate 3 candidates using sampling
>>> outputs = model.generate(input_ids=input_ids, max_length=20, num_return_sequences=3, do_sample=True)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
>>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
>>> model = AutoModelForCausalLM.from_pretrained("ctrl")
>>> # "Legal" is one of the control codes for ctrl
>>> input_context = "Legal My neighbor is"
>>> # encode input context
>>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
>>> outputs = model.generate(input_ids=input_ids, max_length=20, repetition_penalty=1.2)
>>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> input_context = "My cute dog"
>>> # get tokens of words that should not be generated
>>> bad_words_ids = tokenizer(
... ["idiot", "stupid", "shut up"], add_prefix_space=True, add_special_tokens=False
>>> ).input_ids
>>> # get tokens of words that we want generated
>>> force_words_ids = tokenizer(["runs", "loves"], add_prefix_space=True, add_special_tokens=False).input_ids
>>> # encode input context
>>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
>>> # generate sequences without allowing bad_words to be generated
>>> outputs = model.generate(
... input_ids=input_ids,
... max_length=20,
... do_sample=True,
... bad_words_ids=bad_words_ids,
... force_words_ids=force_words_ids,
... )
>>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
>>> prompt = "Today I believe we can finally"
>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
>>> # generate up to 30 tokens
>>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
```
Multinomial Sampling:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> prompt = "Today I believe we can finally"
>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
>>> # sample up to 30 tokens
>>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
>>> outputs = model.generate(input_ids, do_sample=True, max_length=30)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the']
```
Beam-search decoding:
```python
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
>>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
>>> sentence = "Paris is one of the densest populated areas in Europe."
>>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids
>>> outputs = model.generate(input_ids)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
```"""
# 1. Set generation parameters if not already defined
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
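The examples above stop at beam search, while the class docstring also lists constrained beam search; a hedged sketch of triggering it through `force_words_ids` (checkpoint and forced word are illustrative assumptions):

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
input_ids = tokenizer("translate English to German: How old are you?", return_tensors="pt").input_ids

# because force_words_ids != None, generate() routes to constrained_beam_search()
force_words_ids = tokenizer(["Sie"], add_special_tokens=False).input_ids
outputs = model.generate(input_ids, force_words_ids=force_words_ids, num_beams=5)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```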
@@ -1457,7 +1480,8 @@ class GenerationMixin:
**model_kwargs,
) -> Union[GreedySearchOutput, torch.LongTensor]:
r"""
Generates sequences for models with a language modeling head using greedy decoding.
Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
@@ -1508,6 +1532,8 @@
... AutoModelForCausalLM,
... LogitsProcessorList,
... MinLengthLogitsProcessor,
... StoppingCriteriaList,
... MaxLengthCriteria,
... )
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
@@ -1516,26 +1542,30 @@
>>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
>>> model.config.pad_token_id = model.config.eos_token_id
>>> input_prompt = "Today is a beautiful day, and"
>>> input_prompt = "It might be possible to"
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
>>> # instantiate logits processors
>>> logits_processor = LogitsProcessorList(
... [
... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
... MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id),
... ]
... )
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
>>> outputs = model.greedy_search(input_ids, logits_processor=logits_processor)
>>> outputs = model.greedy_search(
... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
... )
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
```"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if max_length is not None:
warnings.warn(
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
UserWarning,
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
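For reference, the replacement that the corrected warning points to looks like this (a sketch reusing the names from the doctest above):

```python
from transformers import MaxLengthCriteria, StoppingCriteriaList

# instead of greedy_search(input_ids, max_length=20):
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
outputs = model.greedy_search(input_ids, stopping_criteria=stopping_criteria)
```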
@@ -1683,7 +1713,8 @@ class GenerationMixin:
**model_kwargs,
) -> Union[SampleOutput, torch.LongTensor]:
r"""
Generates sequences for models with a language modeling head using multinomial sampling.
Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
@@ -1739,7 +1770,10 @@
... MinLengthLogitsProcessor,
... TopKLogitsWarper,
... TemperatureLogitsWarper,
... StoppingCriteriaList,
... MaxLengthCriteria,
... )
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
@@ -1764,9 +1798,18 @@
... ]
... )
>>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper)
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
>>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
>>> outputs = model.sample(
... input_ids,
... logits_processor=logits_processor,
... logits_warper=logits_warper,
... stopping_criteria=stopping_criteria,
... )
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
```"""
# init values
@@ -1926,7 +1969,8 @@ class GenerationMixin:
**model_kwargs,
) -> Union[BeamSearchOutput, torch.LongTensor]:
r"""
Generates sequences for models with a language modeling head using beam search decoding.
Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
@@ -2020,7 +2064,8 @@
>>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Wie alt bist du?']
```"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
@@ -2237,7 +2282,8 @@
**model_kwargs,
) -> Union[BeamSampleOutput, torch.LongTensor]:
r"""
Generates sequences for models with a language modeling head using beam search with multinomial sampling.
Generates sequences of token ids for models with a language modeling head using **beam search multinomial
sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
@@ -2343,7 +2389,8 @@
... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
... )
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Wie alt bist du?']
```"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
@@ -2556,7 +2603,8 @@
**model_kwargs,
):
r"""
Generates sequences for models with a language modeling head using beam search decoding.
Generates sequences of token ids for models with a language modeling head using **diverse beam search
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
@@ -2656,7 +2704,8 @@
... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
... )
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Wie alt bist du?']
```"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
@@ -2920,7 +2969,8 @@
) -> Union[BeamSearchOutput, torch.LongTensor]:
r"""
Generates sequences for models with a language modeling head using beam search decoding.
Generates sequences of token ids for models with a language modeling head using **constrained beam search
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -3024,8 +3074,8 @@
... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
... )
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
# => ['Wie alter sind Sie?']
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Wie alt sind Sie?']
```"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()

utils/documentation_tests.txt

@@ -28,5 +28,6 @@ src/transformers/models/pegasus/modeling_pegasus.py
src/transformers/models/blenderbot/modeling_blenderbot.py
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
src/transformers/models/plbart/modeling_plbart.py
src/transformers/generation_utils.py
docs/source/quicktour.mdx
docs/source/task_summary.mdx