Generate: Add new decoding strategy "DoLa" in .generate() (#29619)
Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in: parent 99c0e55335 · commit d094d8d9ec
@@ -178,7 +178,7 @@ An increasing sequence: one, two, three, four, five, six, seven, eight, nine, te

The `generate()` method supports caching keys and values to enhance efficiency and avoid re-computation. However, the key and value
cache can occupy a large portion of memory, becoming a bottleneck for long-context generation, especially for Large Language Models.
Quantizing the cache when using `generate()` can significantly reduce memory requirements at the cost of speed.

KV Cache quantization in `transformers` is largely inspired by the paper [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache](https://arxiv.org/abs/2402.02750) and currently supports `quanto` and `HQQ` as backends. For more information on the inner workings see the paper.
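As a quick illustration, here is a minimal sketch of enabling the quantized KV cache at generation time. It assumes the `cache_implementation="quantized"` / `cache_config` arguments described above, the `quanto` backend being installed, and an illustrative checkpoint; treat it as a sketch rather than the canonical example from the docs.

```python
# Minimal sketch: quantized KV cache during generation (assumes the `quanto` backend is installed).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "huggyllama/llama-7b"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to(device)

inputs = tokenizer("I like rock music because", return_tensors="pt").to(device)

# `cache_implementation="quantized"` switches the KV cache to a quantized cache;
# `cache_config` selects the backend ("quanto" or "HQQ") and the number of bits.
out = model.generate(
    **inputs,
    max_new_tokens=40,
    cache_implementation="quantized",
    cache_config={"backend": "quanto", "nbits": 4},
)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```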
@@ -213,11 +213,11 @@ I like rock music because it's loud and energetic. I like to listen to it when I

## Watermarking

The `generate()` method supports watermarking the generated text by randomly marking a portion of tokens as "green".
When generating, the "green" tokens have a small 'bias' value added to their logits, and thus have a higher chance of being generated.
The watermarked text can be detected by calculating the proportion of "green" tokens in the text and estimating how likely it is
statistically to obtain that amount of "green" tokens for human-generated text. This watermarking strategy was proposed in the paper
["On the Reliability of Watermarks for Large Language Models"](https://arxiv.org/abs/2306.04634). For more information on
the inner workings of watermarking, it is recommended to refer to the paper.
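As a hedged sketch of the workflow just described, the example below generates watermarked text and then detects it, assuming the `WatermarkingConfig` / `WatermarkDetector` utilities that back this section of the docs; the checkpoint and bias values are illustrative.

```python
# Sketch: watermarked generation and detection, assuming `WatermarkingConfig` / `WatermarkDetector`.
from transformers import AutoModelForCausalLM, AutoTokenizer, WatermarkDetector, WatermarkingConfig

model_id = "openai-community/gpt2"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Alice and Bob are", return_tensors="pt")

# Bias the logits of "green" tokens during generation.
watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
out = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=True, max_new_tokens=20)

# Detection estimates how unlikely the observed share of "green" tokens would be for unwatermarked text.
detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config=watermarking_config)
detection_out = detector(out, return_dict=True)
print(detection_out.prediction)
```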

The watermarking can be used with any generative model in `transformers` and does not require an extra classification model
@@ -484,3 +484,59 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t

Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
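For reference, a minimal sketch of n-gram based assisted decoding via `prompt_lookup_num_tokens` (checkpoint and value are illustrative):

```python
# Sketch: prompt lookup decoding, where candidate continuations come from n-grams in the prompt itself.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "huggyllama/llama-7b"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The keys to a good summary are", return_tensors="pt")

# No assistant model is needed; candidates are proposed by matching n-grams from the prompt.
out = model.generate(**inputs, prompt_lookup_num_tokens=3, max_new_tokens=30)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```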
### DoLa Decoding

**D**ecoding by C**o**ntrasting **La**yers (DoLa) is a contrastive decoding strategy to improve the factuality and reduce the
hallucinations of LLMs, as described in the ICLR 2024 paper [DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language Models](https://arxiv.org/abs/2309.03883).

DoLa is achieved by contrasting the differences in logits obtained from the final
layer versus earlier layers, thus amplifying the factual knowledge localized to particular parts of the transformer layers.

Do the following two steps to activate DoLa decoding when calling the `model.generate` function:
1. Set the `dola_layers` argument, which can be either a string or a list of integers.
    - If set to a string, it can be one of `low` or `high`.
    - If set to a list of integers, it should be a list of layer indices between 0 and the total number of layers in the model. The 0-th layer is the word embedding, the 1st layer is the first transformer layer, and so on.
2. Setting `repetition_penalty = 1.2` is suggested to reduce repetition in DoLa decoding.

See the following examples for DoLa decoding with the 32-layer LLaMA-7B model.

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
>>> import torch

>>> tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
>>> model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", torch_dtype=torch.float16)
>>> device = 'cuda' if torch.cuda.is_available() else 'cpu'
>>> model.to(device)
>>> set_seed(42)

>>> text = "On what date was the Declaration of Independence officially signed?"
>>> inputs = tokenizer(text, return_tensors="pt").to(device)

# Vanilla greedy decoding
>>> vanilla_output = model.generate(**inputs, do_sample=False, max_new_tokens=50)
>>> tokenizer.batch_decode(vanilla_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
['\nThe Declaration of Independence was signed on July 4, 1776.\nWhat was the date of the signing of the Declaration of Independence?\nThe Declaration of Independence was signed on July 4,']

# DoLa decoding with contrasting higher part of layers (layers 16,18,...,30)
>>> dola_high_output = model.generate(**inputs, do_sample=False, max_new_tokens=50, dola_layers='high')
>>> tokenizer.batch_decode(dola_high_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
['\nJuly 4, 1776, when the Continental Congress voted to separate from Great Britain. The 56 delegates to the Continental Congress signed the Declaration on August 2, 1776.']

# DoLa decoding with contrasting specific layers (layers 28 and 30)
>>> dola_custom_output = model.generate(**inputs, do_sample=False, max_new_tokens=50, dola_layers=[28,30], repetition_penalty=1.2)
>>> tokenizer.batch_decode(dola_custom_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
['\nIt was officially signed on 2 August 1776, when 56 members of the Second Continental Congress, representing the original 13 American colonies, voted unanimously for the resolution for independence. The 2']
```

#### Understanding the `dola_layers` argument

`dola_layers` stands for the candidate layers in premature layer selection, as described in the DoLa paper. The selected premature layer will be contrasted with the final layer.

Setting `dola_layers` to `'low'` or `'high'` will select the lower or higher part of the layers to contrast, respectively.
- For `N`-layer models with `N <= 40` layers, the layers of `range(0, N // 2, 2)` and `range(N // 2, N, 2)` are used for `'low'` and `'high'` layers, respectively.
- For models with `N > 40` layers, the layers of `range(0, 20, 2)` and `range(N - 20, N, 2)` are used for `'low'` and `'high'` layers, respectively.
- If the model has tied word embeddings, the word embedding (0-th) layer is skipped and the candidate layers start from the 2nd layer, as the early exit from the word embeddings would become an identity function.
- Set `dola_layers` to a list of integers to contrast manually specified layers. For example, setting `dola_layers=[28,30]` will contrast the final layer (the 32nd layer) with the 28th and 30th layers.

The paper suggests contrasting `'high'` layers to improve short-answer tasks like TruthfulQA, and contrasting `'low'` layers to improve all other long-answer reasoning tasks, such as GSM8K, StrategyQA, FACTOR, and VicunaQA. Applying DoLa to smaller models like GPT-2 is not recommended, as shown in Appendix N of the paper.
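To make the `'low'` / `'high'` bucketing above concrete, here is a small illustrative helper that reproduces the ranges described in this section. The function name `dola_candidate_layers` is hypothetical and not part of the library; it is only a sketch of the selection rule.

```python
# Illustrative only: reproduces the candidate-layer ranges described above.
# `dola_candidate_layers` is a hypothetical helper, not a transformers API.
from typing import List, Union


def dola_candidate_layers(num_layers: int, dola_layers: Union[str, List[int]], start_layer: int = 0) -> List[int]:
    if dola_layers == "low":
        # lower half (capped at the first 20 layers for deep models), every other layer
        end = num_layers // 2 if num_layers <= 40 else 20
        return list(range(start_layer, end, 2))
    if dola_layers == "high":
        # upper half (last 20 layers for deep models), every other layer
        start = num_layers // 2 if num_layers <= 40 else num_layers - 20
        return list(range(start, num_layers, 2))
    # a manual list keeps only indices below the final layer
    return [i for i in dola_layers if i < num_layers]


print(dola_candidate_layers(32, "high"))                 # [16, 18, 20, 22, 24, 26, 28, 30]
print(dola_candidate_layers(32, "low", start_layer=2))   # [2, 4, 6, 8, 10, 12, 14]  (tied embeddings skip layer 0)
```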
@@ -60,6 +60,7 @@ class GenerationMode(ExplicitEnum):
    GREEDY_SEARCH = "greedy_search"
    SAMPLE = "sample"
    ASSISTED_GENERATION = "assisted_generation"
    DOLA_GENERATION = "dola_generation"
    # Beam methods
    BEAM_SEARCH = "beam_search"
    BEAM_SAMPLE = "beam_sample"
@@ -81,6 +82,7 @@ class GenerationConfig(PushToHubMixin):
        - *diverse beam-search decoding* if `num_beams>1` and `num_beam_groups>1`
        - *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
        - *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
        - *dola decoding* if `dola_layers` is passed to `.generate()`

    To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).

@@ -305,6 +307,18 @@ class GenerationConfig(PushToHubMixin):
        max_matching_ngram_size (`int`, *optional*, defaults to `None`):
            The maximum ngram size to be considered for matching in the prompt. Defaults to 2 if not provided.

        > Generation parameters exclusive to [DoLa decoding](https://arxiv.org/abs/2309.03883)

        dola_layers (`str` or `List[int]`, *optional*):
            The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
            be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively.
            "low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the
            layers up to the last 20 layers.
            If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa.
            The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
            `'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
            or [the paper](https://arxiv.org/abs/2309.03883) for more details.

        > Parameters specific to the caching mechanism:

        cache_implementation (`str`, *optional*, defaults to `None`):
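As a usage note for this new config field, here is a hedged sketch of setting `dola_layers` through a `GenerationConfig` object rather than as a direct `generate()` kwarg; the checkpoint name is illustrative.

```python
# Sketch: configuring DoLa via GenerationConfig instead of passing kwargs to generate().
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_id = "huggyllama/llama-7b"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# repetition_penalty >= 1.2 is recommended for DoLa (see the validate() check later in this diff).
generation_config = GenerationConfig(dola_layers="high", repetition_penalty=1.2, max_new_tokens=50)

inputs = tokenizer("On what date was the Declaration of Independence officially signed?", return_tensors="pt")
out = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```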
@@ -397,6 +411,9 @@ class GenerationConfig(PushToHubMixin):
        self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
        self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")

        # DoLa generation
        self.dola_layers = kwargs.pop("dola_layers", None)

        # Cache implementation
        self.cache_implementation = kwargs.pop("cache_implementation", None)
        self.cache_config = kwargs.pop("cache_config", None)
@@ -495,6 +512,16 @@ class GenerationConfig(PushToHubMixin):
                    "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
                    "is only supported with Greedy Search and Sample."
                )

        # DoLa generation may extend some generation modes
        if self.dola_layers is not None:
            if generation_mode in ("greedy_search", "sample"):
                generation_mode = GenerationMode.DOLA_GENERATION
            else:
                raise ValueError(
                    "You've set `dola_layers`, which triggers DoLa generate. Currently, DoLa generate "
                    "is only supported with Greedy Search and Sample."
                )
        return generation_mode

    def validate(self, is_init=False):
@@ -700,6 +727,17 @@ class GenerationConfig(PushToHubMixin):
                "`generate()` (or a pipeline) directly."
            )

        # 6. if dola_layers is set, check if repetition_penalty is set to >= 1.2
        if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
            dola_decoding_wrong_parameter_msg = (
                "`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of {repetition_penalty}, "
                "which could induce unwanted repetition. The recommended value for DoLa decoding is `repetition_penalty>=1.2`."
            )
            warnings.warn(
                dola_decoding_wrong_parameter_msg.format(repetition_penalty=self.repetition_penalty),
                UserWarning,
            )

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
@@ -20,9 +20,11 @@ import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from torch.nn import functional as F

from ..cache_utils import (
    Cache,
@@ -1908,6 +1910,28 @@ class GenerationMixin:
                streamer=streamer,
                **model_kwargs,
            )
        elif generation_mode == GenerationMode.DOLA_GENERATION:
            if self._is_stateful:
                # DoLa decoding was not designed for stateful models, and would require some changes
                raise ValueError(
                    f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}"
                )
            prepared_logits_warper = (
                self._get_logits_warper(generation_config, device=input_ids.device)
                if generation_config.do_sample
                else None
            )
            result = self._dola_decoding(
                input_ids,
                dola_layers=generation_config.dola_layers,
                logits_processor=prepared_logits_processor,
                logits_warper=prepared_logits_warper,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                streamer=streamer,
                **model_kwargs,
            )

        elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
            if not model_kwargs["use_cache"]:
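Since this branch prepares a logits warper when `do_sample=True`, DoLa composes with sampling as well as greedy search. A hedged user-level sketch of that combination follows; the checkpoint and sampling values are illustrative.

```python
# Sketch: DoLa combined with multinomial sampling (the dispatch above prepares a logits warper).
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

model_id = "huggyllama/llama-7b"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
set_seed(42)

inputs = tokenizer("On what date was the Declaration of Independence officially signed?", return_tensors="pt")
out = model.generate(
    **inputs,
    do_sample=True,          # triggers the logits-warper path in the DOLA_GENERATION branch
    temperature=0.7,
    top_p=0.9,
    dola_layers="high",
    repetition_penalty=1.2,
    max_new_tokens=50,
)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```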
@ -2189,6 +2213,225 @@ class GenerationMixin:
|
||||
)
|
||||
return self._contrastive_search(*args, **kwargs)
|
||||
|
||||
def _dola_decoding(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
dola_layers: Union[str, List[int]],
|
||||
logits_processor: LogitsProcessorList,
|
||||
stopping_criteria: StoppingCriteriaList,
|
||||
generation_config: GenerationConfig,
|
||||
synced_gpus: bool,
|
||||
streamer: Optional["BaseStreamer"] = None,
|
||||
logits_warper: Optional[LogitsProcessorList] = None,
|
||||
**model_kwargs,
|
||||
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
||||
r"""
|
||||
Generates sequences of token ids for models with a language modeling head using **dola decoding** and can be
|
||||
used for decoder-only text models.
|
||||
The method is based on the paper "DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language
|
||||
Models" (https://arxiv.org/abs/2309.03883) in ICLR 2024.
|
||||
|
||||
Parameters:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
The sequence used as a prompt for the generation.
|
||||
dola_layers (`Union[str, List[int]]`):
|
||||
The candidate layers used in contrasting layers of DoLa. It can be either 1) 'low' or 'high', which
|
||||
means the lower part or higher part of the model layers, respectively, or 2) a list of layer indices
|
||||
to be used for candidate layers. The 0-th layer is the word embedding layer of the model.
|
||||
logits_processor (`LogitsProcessorList`):
|
||||
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
||||
used to modify the prediction scores of the language modeling head applied at each generation step.
|
||||
stopping_criteria (`StoppingCriteriaList`, *optional*):
|
||||
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
||||
used to tell if the generation loop should stop.
|
||||
generation_config ([`~generation.GenerationConfig`]):
|
||||
The generation configuration to be used as parametrization of the decoding method.
|
||||
synced_gpus (`bool`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
streamer (`BaseStreamer`, *optional*):
|
||||
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||
logits_warper (`LogitsProcessorList`, *optional*):
|
||||
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
|
||||
to warp the prediction score distribution of the language modeling head applied before multinomial
|
||||
sampling at each generation step.
|
||||
model_kwargs:
|
||||
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
|
||||
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
|
||||
|
||||
Return:
|
||||
[`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`]
|
||||
or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
|
||||
[`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
||||
`return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
|
||||
`model.config.is_encoder_decoder=True`.
|
||||
"""
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
raise ValueError("DoLa decoding is only available for decoder-only models.")
|
||||
# init values
|
||||
|
||||
pad_token_id = generation_config.pad_token_id
|
||||
output_attentions = generation_config.output_attentions
|
||||
output_hidden_states = generation_config.output_hidden_states
|
||||
output_scores = generation_config.output_scores
|
||||
output_logits = generation_config.output_logits
|
||||
return_dict_in_generate = generation_config.return_dict_in_generate
|
||||
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
|
||||
do_sample = generation_config.do_sample
|
||||
if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
|
||||
raise ValueError(
|
||||
"`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
|
||||
f"{logits_warper})."
|
||||
)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
raw_logits = () if (return_dict_in_generate and output_logits) else None
|
||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
batch_size = input_ids.shape[0]
|
||||
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
||||
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
|
||||
|
||||
this_peer_finished = False
|
||||
|
||||
# prepare layers for DoLa decoding
|
||||
final_layer = self.config.num_hidden_layers
|
||||
# if the model has tied word embeddings, we skip the word embeddings (0-th) layer and start from the 2nd layer,
|
||||
# as the early exit from word embeddings will become identity function
|
||||
# if the model is really shallow (<=2 layers), we use the 1st layer if it's not the final layer and the 0-th
|
||||
# layer otherwise. Notice that DoLa does not help shallow models much.
|
||||
if not self.config.tie_word_embeddings:
|
||||
start_layer = 0
|
||||
elif final_layer > 2:
|
||||
start_layer = 2
|
||||
elif final_layer == 2:
|
||||
start_layer = 1
|
||||
else:
|
||||
start_layer = 0
|
||||
|
||||
# For `N`-layer models with `N <= 40` layers, the layers of `range(0, N // 2, 2)` and `range(N // 2, N, 2)`
|
||||
# are used for `'low'` and `'high'` layers, respectively.
|
||||
# For models with `N > 40` layers, the layers of `range(0, 20, 2)` and `range(N - 20, N, 2)` are used for
|
||||
# `'low'` and `'high'` layers, respectively.
|
||||
if isinstance(dola_layers, str) and dola_layers == "low":
|
||||
if start_layer == final_layer // 2:
|
||||
candidate_premature_layers = [start_layer]
|
||||
else:
|
||||
candidate_premature_layers = (
|
||||
list(range(start_layer, final_layer // 2, 2))
|
||||
if final_layer <= 40
|
||||
else list(range(start_layer, 20, 2))
|
||||
)
|
||||
elif isinstance(dola_layers, str) and dola_layers == "high":
|
||||
candidate_premature_layers = (
|
||||
list(range(final_layer // 2, final_layer, 2))
|
||||
if final_layer <= 40
|
||||
else list(range(final_layer - 20, final_layer, 2))
|
||||
)
|
||||
# Set the `dola_layers` to a list of integers for layer indices to contrast manually specified layers.
|
||||
elif isinstance(dola_layers, list):
|
||||
candidate_premature_layers = [i for i in dola_layers if i < final_layer]
|
||||
else:
|
||||
raise ValueError("dola_layers must be either 'low', 'high' or a list of integers.")
|
||||
|
||||
lm_head = self.get_output_embeddings()
|
||||
if lm_head is None:
|
||||
raise ValueError("DoLa is not supported for models that don't have output embeddings.")
|
||||
|
||||
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
||||
# prepare model inputs
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
|
||||
# forward pass to get next token
|
||||
outputs = self(
|
||||
**model_inputs,
|
||||
return_dict=True,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
final_layer_next_token_logits = outputs.logits[:, -1, :].detach().clone()
|
||||
final_logits = outputs.logits[:, -1, :]
|
||||
candidate_premature_logits = {}
|
||||
for candidate_premature_layer in candidate_premature_layers:
|
||||
candidate_premature_logits[candidate_premature_layer] = lm_head(
|
||||
outputs.hidden_states[candidate_premature_layer][:, -1, :]
|
||||
)
|
||||
|
||||
if synced_gpus and this_peer_finished:
|
||||
continue # don't waste resources running the code we don't need
|
||||
|
||||
next_token_logits = _dola_select_contrast(
|
||||
candidate_premature_layers, candidate_premature_logits, final_logits
|
||||
)
|
||||
# pre-process distribution
|
||||
next_token_scores = logits_processor(input_ids, next_token_logits)
|
||||
if do_sample: # sample
|
||||
next_token_scores = logits_warper(input_ids, next_token_scores)
|
||||
# Store scores, attentions and hidden_states when required
|
||||
if return_dict_in_generate:
|
||||
if output_scores:
|
||||
scores += (next_token_scores,)
|
||||
if output_logits:
|
||||
raw_logits += (final_layer_next_token_logits,)
|
||||
if output_attentions:
|
||||
decoder_attentions += (
|
||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||
)
|
||||
if self.config.is_encoder_decoder:
|
||||
cross_attentions += (outputs.cross_attentions,)
|
||||
|
||||
if output_hidden_states:
|
||||
decoder_hidden_states += (
|
||||
(outputs.decoder_hidden_states,)
|
||||
if self.config.is_encoder_decoder
|
||||
else (outputs.hidden_states,)
|
||||
)
|
||||
|
||||
if do_sample: # sample
|
||||
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
else: # argmax
|
||||
next_tokens = torch.argmax(next_token_scores, dim=-1)
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if has_eos_stopping_criteria:
|
||||
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
||||
|
||||
# update generated ids, model inputs, and length for next step
|
||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||
if streamer is not None:
|
||||
streamer.put(next_tokens.cpu())
|
||||
model_kwargs = self._update_model_kwargs_for_generation(
|
||||
outputs,
|
||||
model_kwargs,
|
||||
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||
)
|
||||
|
||||
# stop when each sentence is finished
|
||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||
this_peer_finished = unfinished_sequences.max() == 0
|
||||
|
||||
if streamer is not None:
|
||||
streamer.end()
|
||||
|
||||
if return_dict_in_generate:
|
||||
return GenerateDecoderOnlyOutput(
|
||||
sequences=input_ids,
|
||||
scores=scores,
|
||||
logits=raw_logits,
|
||||
attentions=decoder_attentions,
|
||||
hidden_states=decoder_hidden_states,
|
||||
past_key_values=model_kwargs.get("past_key_values"),
|
||||
)
|
||||
else:
|
||||
return input_ids
|
||||
|
||||
@torch.no_grad()
|
||||
def _contrastive_search(
|
||||
self,
|
||||
@ -4197,3 +4440,75 @@ def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
|
||||
|
||||
# Return a new object of the inferred class with the concatenated attributes
|
||||
return model_output_cls(**concatenated_data)
|
||||
|
||||
|
||||
def _relative_top_filter(
|
||||
scores: torch.FloatTensor,
|
||||
baseline_scores: torch.FloatTensor,
|
||||
relative_top: float = 0.1,
|
||||
filter_value: float = -float("Inf"),
|
||||
base_filter_value=-1e-3,
|
||||
min_tokens_to_keep: int = 1,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Reference: https://github.com/XiangLi1999/ContrastiveDecoding/blob/170e9142e92159c1237d731e240f5eb14aabf428/transformers/src/transformers/generation_logits_process.py#L235
|
||||
Apply filtering to only keep tokens with a probability above a certain threshold. The threshold is defined as `relative_top` * max probability in the distribution.
|
||||
"""
|
||||
scores_normalized = scores.log_softmax(dim=-1)
|
||||
baseline_scores_normalized = baseline_scores.log_softmax(dim=-1)
|
||||
sorted_logits, sorted_indices = torch.sort(scores_normalized, descending=True)
|
||||
min_thresh = sorted_logits[..., min_tokens_to_keep - 1]
|
||||
probs_max = torch.max(scores_normalized, dim=-1).values
|
||||
probs_thresh = probs_max + np.log(relative_top)
|
||||
probs_thresh = torch.min(min_thresh, probs_thresh)
|
||||
probs_thresh = probs_thresh.unsqueeze(-1)
|
||||
baseline_scores_normalized[scores_normalized < probs_thresh] = base_filter_value
|
||||
scores_normalized[scores_normalized < probs_thresh] = filter_value
|
||||
return scores_normalized, baseline_scores_normalized
|
||||
|
||||
|
||||
def _dola_select_contrast(
|
||||
candidate_premature_layers: List[int],
|
||||
candidate_premature_logits: Dict[int, torch.FloatTensor],
|
||||
final_logits: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
if len(candidate_premature_layers) == 1:
|
||||
base_logits = candidate_premature_logits[candidate_premature_layers[0]]
|
||||
final_logits, base_logits = _relative_top_filter(final_logits, base_logits)
|
||||
logits = final_logits - base_logits
|
||||
return logits
|
||||
|
||||
# 1. Stacking all premature_layers into a new dimension
|
||||
stacked_premature_layers = torch.stack([candidate_premature_logits[i] for i in candidate_premature_layers], dim=0)
|
||||
|
||||
# 2. Calculate the softmax values for mature_layer and all premature_layers
|
||||
# shape: (batch_size, vocab_size)
|
||||
softmax_mature_layer = F.softmax(final_logits, dim=-1)
|
||||
# shape: (num_premature_layers, batch_size, vocab_size)
|
||||
softmax_premature_layers = F.softmax(stacked_premature_layers, dim=-1)
|
||||
|
||||
# 3. Calculate the average distribution
|
||||
# shape: (num_premature_layers, batch_size, vocab_size)
|
||||
avg_dist = 0.5 * (softmax_mature_layer[None, :, :] + softmax_premature_layers)
|
||||
|
||||
# 4. Calculate log-softmax for the KL divergence
|
||||
# shape: (batch_size, vocab_size)
|
||||
log_softmax_mature_layer = F.log_softmax(final_logits, dim=-1)
|
||||
# shape: (num_premature_layers, batch_size, vocab_size)
|
||||
log_softmax_premature_layers = F.log_softmax(stacked_premature_layers, dim=-1)
|
||||
|
||||
# 5. Calculate the KL divergences and then the JS divergences
|
||||
# shape: (num_premature_layers, batch_size)
|
||||
kl1 = F.kl_div(log_softmax_mature_layer[None, :, :], avg_dist, reduction="none").mean(-1)
|
||||
# shape: (num_premature_layers, batch_size)
|
||||
kl2 = F.kl_div(log_softmax_premature_layers, avg_dist, reduction="none").mean(-1)
|
||||
js_divs = 0.5 * (kl1 + kl2) # shape: (num_premature_layers, batch_size)
|
||||
|
||||
# 6. Reduce the batchmean
|
||||
js_divs = js_divs.mean(-1) # shape: (num_premature_layers,)
|
||||
premature_layer = candidate_premature_layers[int(js_divs.argmax().cpu().item())]
|
||||
|
||||
base_logits = candidate_premature_logits[premature_layer]
|
||||
final_logits, base_logits = _relative_top_filter(final_logits, base_logits)
|
||||
logits = final_logits - base_logits
|
||||
return logits
|
||||
|
@ -1264,6 +1264,55 @@ class GenerationTesterMixin:
|
||||
for output in (output_greedy, output_prompt_lookup):
|
||||
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
||||
|
||||
def test_dola_decoding_sample(self):
|
||||
# TODO (joao): investigate skips, try to reduce incompatibilities
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest(reason="Stateful models don't support DoLa decoding")
|
||||
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
|
||||
self.skipTest("Skip Reformer as the lm_head input size is 2 * hidden size, adopted from Rev Nets.")
|
||||
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["marian", "mbart", "pegasus"]):
|
||||
self.skipTest("DoLa is not supported for models that don't return layerwise hidden states")
|
||||
|
||||
# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# Some models don't support the cache and returning past_key_values
|
||||
if not hasattr(config, "use_cache"):
|
||||
config.use_cache = False
|
||||
else:
|
||||
config.use_cache = True
|
||||
|
||||
# Encoder-decoder models are not supported
|
||||
if config.is_encoder_decoder:
|
||||
self.skipTest("DoLa is not supported for encoder-decoder models")
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
if model.get_output_embeddings() is None:
|
||||
self.skipTest("DoLa is not supported for models that don't have output embeddings")
|
||||
# Sets dola generation arguments such that:
|
||||
# a) no EOS is generated, to ensure generation doesn't break early
|
||||
# b) there are at least two forward passes in the main model, to ensure the input preparation of
|
||||
# the main model is correct
|
||||
generation_kwargs = {
|
||||
"eos_token_id": -1, # see a)
|
||||
"max_new_tokens": 4, # see b)
|
||||
"num_beams": 1,
|
||||
"do_sample": True,
|
||||
"output_scores": True,
|
||||
"output_logits": True,
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": self.has_attentions,
|
||||
"return_dict_in_generate": True,
|
||||
}
|
||||
generation_kwargs.update({"dola_layers": "low"})
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs)
|
||||
self._check_outputs(output_dola, input_ids, model.config, use_cache=config.use_cache)
|
||||
|
||||
def test_assisted_decoding_sample(self):
|
||||
# In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
|
||||
# match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
|
||||
|
@ -839,7 +839,6 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||
|
||||
@slow
|
||||
@ -898,3 +897,24 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
|
||||
|
||||
def test_model_2b_bf16_dola(self):
|
||||
model_id = "google/gemma-2b"
|
||||
# ground truth text generated with dola_layers="low", repetition_penalty=1.2
|
||||
EXPECTED_TEXTS = [
|
||||
"Hello I am doing an experiment and need to get the mass of a block. The problem is, it has no scale",
|
||||
"Hi today we have the review for a <strong>2016/2017</strong> season of",
|
||||
]
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
|
||||
|
||||
output = model.generate(
|
||||
**inputs, max_new_tokens=20, do_sample=False, dola_layers="low", repetition_penalty=1.2
|
||||
)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
@ -703,6 +703,29 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_model_7b_dola_generation(self):
|
||||
# ground truth text generated with dola_layers="low", repetition_penalty=1.2
|
||||
EXPECTED_TEXT_COMPLETION = (
|
||||
"Simply put, the theory of relativity states that 1) time and space are relative, and 2) the laws of "
|
||||
"physics are the same for all observers in uniform motion relative to one another.\n\nThe theory of "
|
||||
"relativity was developed by Albert Einstein in the early 20th century, and it revolutionized our "
|
||||
"understanding of space and time."
|
||||
)
|
||||
prompt = "Simply put, the theory of relativity states that "
|
||||
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16
|
||||
)
|
||||
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
# greedy generation outputs
|
||||
generated_ids = model.generate(
|
||||
**model_inputs, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low"
|
||||
)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_read_token
|
||||
|
@ -555,6 +555,30 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
|
||||
|
||||
@slow
|
||||
def test_model_7b_dola_generation(self):
|
||||
# ground truth text generated with dola_layers="low", repetition_penalty=1.2
|
||||
EXPECTED_TEXT_COMPLETION = (
|
||||
"""My favourite condiment is 100% ketchup. I love it on everything, and I’m not ash"""
|
||||
)
|
||||
prompt = "My favourite condiment is "
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
|
||||
model = MistralForCausalLM.from_pretrained(
|
||||
"mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
|
||||
)
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
|
||||
|
||||
# greedy generation outputs
|
||||
generated_ids = model.generate(
|
||||
input_ids, max_new_tokens=20, temperature=0, dola_layers="low", repetition_penalty=1.2
|
||||
)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
@require_bitsandbytes
|
||||
@slow
|
||||
@require_flash_attn
|
||||
|