<!--Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Utilities for Generation

This page lists all the utility functions used by [`~generation.GenerationMixin.generate`],
[`~generation.GenerationMixin.greedy_search`],
[`~generation.GenerationMixin.contrastive_search`],
[`~generation.GenerationMixin.sample`],
[`~generation.GenerationMixin.beam_search`],
[`~generation.GenerationMixin.beam_sample`],
[`~generation.GenerationMixin.group_beam_search`], and
[`~generation.GenerationMixin.constrained_beam_search`].

Most of those are only useful if you are studying the code of the generate methods in the library.

## Generate Outputs

The output of [`~generation.GenerationMixin.generate`] is an instance of a subclass of
[`~utils.ModelOutput`]. This output is a data structure containing all the information returned
by [`~generation.GenerationMixin.generate`], but it can also be used as a tuple or a dictionary.

Here's an example:

```python
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
```

The `generation_output` object is a [`~generation.GreedySearchDecoderOnlyOutput`]. As we can
see in the documentation of that class below, this means it has the following attributes:

- `sequences`: the generated sequences of tokens
- `scores` (optional): the prediction scores of the language modeling head, for each generation step
- `hidden_states` (optional): the hidden states of the model, for each generation step
- `attentions` (optional): the attention weights of the model, for each generation step

Here we have the `scores` since we passed along `output_scores=True`, but we don't have `hidden_states` and
`attentions` because we didn't pass `output_hidden_states=True` or `output_attentions=True`.

You can access each attribute as you would usually do, and if that attribute has not been returned by the model, you
will get `None`. Here, for instance, `generation_output.scores` are all the generated prediction scores of the
language modeling head, and `generation_output.attentions` is `None`.

When using our `generation_output` object as a tuple, it only keeps the attributes that don't have `None` values.
Here, for instance, it has two elements, `sequences` then `scores`, so

```python
generation_output[:2]
```

will return the tuple `(generation_output.sequences, generation_output.scores)`.

When using our `generation_output` object as a dictionary, it only keeps the attributes that don't have `None`
values. Here, for instance, it has two keys that are `sequences` and `scores`.
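
To make these access patterns concrete, here is a short sketch reusing the `generation_output` object from the
example above; the assertions simply restate what the paragraphs above describe:

```python
# Attribute access: optional outputs that were not requested are `None`.
sequences = generation_output.sequences
assert generation_output.attentions is None

# Dictionary access: only the non-`None` attributes appear as keys.
assert list(generation_output.keys()) == ["sequences", "scores"]
scores = generation_output["scores"]

# Tuple access: the same non-`None` attributes, in declaration order.
first, second = generation_output[:2]
assert first is generation_output.sequences and second is generation_output.scores
```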

We document here all output types.

### PyTorch

[[autodoc]] generation.GreedySearchEncoderDecoderOutput

[[autodoc]] generation.GreedySearchDecoderOnlyOutput

[[autodoc]] generation.SampleEncoderDecoderOutput

[[autodoc]] generation.SampleDecoderOnlyOutput

[[autodoc]] generation.BeamSearchEncoderDecoderOutput

[[autodoc]] generation.BeamSearchDecoderOnlyOutput

[[autodoc]] generation.BeamSampleEncoderDecoderOutput

[[autodoc]] generation.BeamSampleDecoderOnlyOutput

[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput

[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput

### TensorFlow

[[autodoc]] generation.TFGreedySearchEncoderDecoderOutput

[[autodoc]] generation.TFGreedySearchDecoderOnlyOutput

[[autodoc]] generation.TFSampleEncoderDecoderOutput

[[autodoc]] generation.TFSampleDecoderOnlyOutput

[[autodoc]] generation.TFBeamSearchEncoderDecoderOutput

[[autodoc]] generation.TFBeamSearchDecoderOnlyOutput

[[autodoc]] generation.TFBeamSampleEncoderDecoderOutput

[[autodoc]] generation.TFBeamSampleDecoderOnlyOutput

[[autodoc]] generation.TFContrastiveSearchEncoderDecoderOutput

[[autodoc]] generation.TFContrastiveSearchDecoderOnlyOutput

### FLAX

[[autodoc]] generation.FlaxSampleOutput

[[autodoc]] generation.FlaxGreedySearchOutput

[[autodoc]] generation.FlaxBeamSearchOutput

## LogitsProcessor

A [`LogitsProcessor`] can be used to modify the prediction scores of a language model head for generation.

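As an illustration, here is a minimal sketch of a custom processor; `TokenBiasLogitsProcessor` is a hypothetical
example written for this page, not a class shipped with the library. It adds a constant bias to the scores of a
chosen set of token ids and is passed to [`~generation.GenerationMixin.generate`] through its `logits_processor`
argument:

```python
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList


class TokenBiasLogitsProcessor(LogitsProcessor):
    """Adds a constant bias to the scores of the given token ids at every generation step."""

    def __init__(self, token_ids, bias):
        self.token_ids = token_ids
        self.bias = bias

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # `scores` has shape (batch_size, vocab_size); bias the chosen columns.
        scores[:, self.token_ids] = scores[:, self.token_ids] + self.bias
        return scores


tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello, my dog is", return_tensors="pt")

# Discourage the EOS token so generation tends to run until max_length.
processors = LogitsProcessorList([TokenBiasLogitsProcessor([tokenizer.eos_token_id], bias=-10.0)])
outputs = model.generate(**inputs, logits_processor=processors)
```
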
### PyTorch

[[autodoc]] AlternatingCodebooksLogitsProcessor
    - __call__

[[autodoc]] ClassifierFreeGuidanceLogitsProcessor
    - __call__

[[autodoc]] EncoderNoRepeatNGramLogitsProcessor
    - __call__

[[autodoc]] EncoderRepetitionPenaltyLogitsProcessor
    - __call__

[[autodoc]] EpsilonLogitsWarper
    - __call__

[[autodoc]] EtaLogitsWarper
    - __call__

[[autodoc]] ExponentialDecayLengthPenalty
    - __call__

[[autodoc]] ForcedBOSTokenLogitsProcessor
    - __call__

[[autodoc]] ForcedEOSTokenLogitsProcessor
    - __call__

[[autodoc]] ForceTokensLogitsProcessor
    - __call__

[[autodoc]] HammingDiversityLogitsProcessor
    - __call__

[[autodoc]] InfNanRemoveLogitsProcessor
    - __call__

[[autodoc]] LogitNormalization
    - __call__

[[autodoc]] LogitsProcessor
    - __call__

[[autodoc]] LogitsProcessorList
    - __call__

[[autodoc]] LogitsWarper
    - __call__

[[autodoc]] MinLengthLogitsProcessor
    - __call__

[[autodoc]] MinNewTokensLengthLogitsProcessor
    - __call__

[[autodoc]] NoBadWordsLogitsProcessor
    - __call__

[[autodoc]] NoRepeatNGramLogitsProcessor
    - __call__

[[autodoc]] PrefixConstrainedLogitsProcessor
    - __call__

[[autodoc]] RepetitionPenaltyLogitsProcessor
    - __call__

[[autodoc]] SequenceBiasLogitsProcessor
    - __call__

[[autodoc]] SuppressTokensAtBeginLogitsProcessor
    - __call__

[[autodoc]] SuppressTokensLogitsProcessor
    - __call__

[[autodoc]] TemperatureLogitsWarper
    - __call__

[[autodoc]] TopKLogitsWarper
    - __call__

[[autodoc]] TopPLogitsWarper
    - __call__

[[autodoc]] TypicalLogitsWarper
    - __call__

[[autodoc]] UnbatchedClassifierFreeGuidanceLogitsProcessor
    - __call__

[[autodoc]] WhisperTimeStampLogitsProcessor
    - __call__

### TensorFlow

[[autodoc]] TFForcedBOSTokenLogitsProcessor
    - __call__

[[autodoc]] TFForcedEOSTokenLogitsProcessor
    - __call__

[[autodoc]] TFForceTokensLogitsProcessor
    - __call__

[[autodoc]] TFLogitsProcessor
    - __call__

[[autodoc]] TFLogitsProcessorList
    - __call__

[[autodoc]] TFLogitsWarper
    - __call__

[[autodoc]] TFMinLengthLogitsProcessor
    - __call__

[[autodoc]] TFNoBadWordsLogitsProcessor
    - __call__

[[autodoc]] TFNoRepeatNGramLogitsProcessor
    - __call__

[[autodoc]] TFRepetitionPenaltyLogitsProcessor
    - __call__

[[autodoc]] TFSuppressTokensAtBeginLogitsProcessor
    - __call__

[[autodoc]] TFSuppressTokensLogitsProcessor
    - __call__

[[autodoc]] TFTemperatureLogitsWarper
    - __call__

[[autodoc]] TFTopKLogitsWarper
    - __call__

[[autodoc]] TFTopPLogitsWarper
    - __call__

### FLAX

[[autodoc]] FlaxForcedBOSTokenLogitsProcessor
    - __call__

[[autodoc]] FlaxForcedEOSTokenLogitsProcessor
    - __call__

[[autodoc]] FlaxForceTokensLogitsProcessor
    - __call__

[[autodoc]] FlaxLogitsProcessor
    - __call__

[[autodoc]] FlaxLogitsProcessorList
    - __call__

[[autodoc]] FlaxLogitsWarper
    - __call__

[[autodoc]] FlaxMinLengthLogitsProcessor
    - __call__

[[autodoc]] FlaxSuppressTokensAtBeginLogitsProcessor
    - __call__

[[autodoc]] FlaxSuppressTokensLogitsProcessor
    - __call__

[[autodoc]] FlaxTemperatureLogitsWarper
    - __call__

[[autodoc]] FlaxTopKLogitsWarper
    - __call__

[[autodoc]] FlaxTopPLogitsWarper
    - __call__

[[autodoc]] FlaxWhisperTimeStampLogitsProcessor
    - __call__

## StoppingCriteria

A [`StoppingCriteria`] can be used to change when to stop generation (other than an EOS token). Please note that this is exclusively available to our PyTorch implementations.

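As an illustration, here is a minimal sketch of a custom criterion; `StopOnTokenCriteria` is a hypothetical example
written for this page, not a class shipped with the library. It stops generation once every sequence in the batch
ends with a given token id, and is passed to [`~generation.GenerationMixin.generate`] through its
`stopping_criteria` argument:

```python
import torch

from transformers import StoppingCriteria, StoppingCriteriaList


class StopOnTokenCriteria(StoppingCriteria):
    """Returns True (i.e. stop) once every sequence in the batch ends with `stop_token_id`."""

    def __init__(self, stop_token_id: int):
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return bool((input_ids[:, -1] == self.stop_token_id).all())


# Usage: model.generate(**inputs, stopping_criteria=StoppingCriteriaList([StopOnTokenCriteria(13)]))
```
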
[[autodoc]] StoppingCriteria
    - __call__

[[autodoc]] StoppingCriteriaList
    - __call__

[[autodoc]] MaxLengthCriteria
    - __call__

[[autodoc]] MaxTimeCriteria
    - __call__

## Constraints

A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output. Please note that this is exclusively available to our PyTorch implementations.

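As an illustration, here is a minimal sketch of constrained generation with a [`PhrasalConstraint`]. The prompt,
model, and forced phrase are arbitrary choices for this page; note that constrained generation requires beam
search, i.e. `num_beams > 1`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, PhrasalConstraint

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The weather today is", return_tensors="pt")

# Force the phrase "very cold" to appear somewhere in every generated sequence.
force_ids = tokenizer("very cold", add_special_tokens=False).input_ids
constraints = [PhrasalConstraint(force_ids)]

outputs = model.generate(**inputs, constraints=constraints, num_beams=4, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
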
[[autodoc]] Constraint

[[autodoc]] PhrasalConstraint

[[autodoc]] DisjunctiveConstraint

[[autodoc]] ConstraintListState

## BeamSearch

[[autodoc]] BeamScorer
    - process
    - finalize

[[autodoc]] BeamSearchScorer
    - process
    - finalize

[[autodoc]] ConstrainedBeamSearchScorer
    - process
    - finalize

## Utilities

[[autodoc]] top_k_top_p_filtering

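As a quick illustration (shapes and values here are arbitrary), the PyTorch version applies top-k and top-p
filtering directly to a batch of next-token logits:

```python
import torch

from transformers import top_k_top_p_filtering

# A batch of next-token logits (batch_size=1, vocab_size=10).
logits = torch.randn(1, 10)

# Keep the top 5 tokens, then keep tokens within cumulative probability 0.95;
# everything filtered out is set to -inf so it can never be sampled.
filtered = top_k_top_p_filtering(logits, top_k=5, top_p=0.95)
probs = torch.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```
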
[[autodoc]] tf_top_k_top_p_filtering

## Streamers
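
Streamers let you consume tokens as they are generated instead of waiting for
[`~generation.GenerationMixin.generate`] to finish. As a quick sketch (model and prompt are arbitrary choices),
[`TextStreamer`] prints the decoded text to stdout as it arrives:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("An increasing sequence: one,", return_tensors="pt")

# The streamer receives new tokens from `generate` and prints the decoded text as it goes.
streamer = TextStreamer(tokenizer)
_ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
```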

[[autodoc]] TextStreamer

[[autodoc]] TextIteratorStreamer

## Caches
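
A [`Cache`] holds the key/value states computed at previous generation steps. [`DynamicCache`] is the default cache
that grows with the generated sequence, while [`SinkCache`] implements the attention-sinks scheme of
[StreamingLLM](https://arxiv.org/abs/2309.17453), keeping a few initial "sink" tokens plus a sliding window of
recent tokens. As a sketch of the intended usage (the checkpoint is an arbitrary choice, and support depends on the
model having been ported to the new cache API, e.g. Llama, Mistral, Persimmon, Phi), a cache instance is passed to
[`~generation.GenerationMixin.generate`] via `past_key_values`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

inputs = tokenizer("Yesterday I was on a rock concert and", return_tensors="pt")

# Cache at most `window_length` tokens, the first `num_sink_tokens` of which are attention sinks.
past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
outputs = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=100)
```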

[[autodoc]] Cache
    - update

[[autodoc]] DynamicCache
    - update
    - get_seq_length
    - reorder_cache
    - to_legacy_cache
    - from_legacy_cache

[[autodoc]] SinkCache
    - update
    - get_seq_length
    - reorder_cache