Generate: Add new decoding strategy "DoLa" in .generate() (#29619)

Co-authored-by: Joao Gante <joao@huggingface.co>
Yung-Sung Chuang 2024-07-09 09:37:38 -07:00 committed by GitHub
parent 99c0e55335
commit d094d8d9ec
7 changed files with 530 additions and 5 deletions


@ -178,7 +178,7 @@ An increasing sequence: one, two, three, four, five, six, seven, eight, nine, te
The `generate()` method supports caching keys and values to enhance efficiency and avoid re-computations. However, the key and value
cache can occupy a large portion of memory, becoming a bottleneck for long-context generation, especially for Large Language Models.
Quantizing the cache when using `generate()` can significantly reduce memory requirements at the cost of speed.
KV Cache quantization in `transformers` is largely inspired by the paper [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache](https://arxiv.org/abs/2402.02750) and currently supports `quanto` and `HQQ` as backends. For more information on the inner workings see the paper.
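As a quick, hedged illustration of the API described above (the `cache_config` keys `nbits` and `backend` are assumptions drawn from the quantized-cache documentation, not from this commit):

```python
# Minimal sketch of KV cache quantization with `generate()`; assumes the `quanto` backend is installed.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", torch_dtype=torch.float16, device_map="auto")
inputs = tokenizer("An increasing sequence: one, two, three,", return_tensors="pt").to(model.device)

out = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=20,
    cache_implementation="quantized",                 # switch the KV cache to its quantized variant
    cache_config={"nbits": 4, "backend": "quanto"},   # assumed keys; see the quantized-cache docs
)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```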
@ -213,11 +213,11 @@ I like rock music because it's loud and energetic. I like to listen to it when I
## Watermarking
The `generate()` method supports watermarking the generated text by randomly marking a portion of tokens as "green".
When generating, the "green" tokens will have a small 'bias' value added to their logits and thus a higher chance of being generated.
The watermarked text can be detected by calculating the proportion of "green" tokens in the text and estimating how likely it is
statistically to obtain that amount of "green" tokens for human-generated text. This watermarking strategy was proposed in the paper
["On the Reliability of Watermarks for Large Language Models"](https://arxiv.org/abs/2306.04634). For more information on
the inner functioning of watermarking, it is recommended to refer to the paper.
Watermarking can be used with any generative model in `transformers` and does not require an extra classification model
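A hedged sketch of what this looks like in practice; the `WatermarkingConfig` fields used below (`bias`, `seeding_scheme`) are assumptions based on the watermarking utilities in `transformers`, not part of this commit:

```python
# Minimal sketch: bias "green" tokens during generation via a watermarking config (fields assumed).
from transformers import AutoModelForCausalLM, AutoTokenizer, WatermarkingConfig

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="auto")
inputs = tokenizer("Alice and Bob are", return_tensors="pt").to(model.device)

watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
out = model.generate(**inputs, max_new_tokens=20, watermarking_config=watermarking_config)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```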
@ -484,3 +484,59 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t
Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
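For example, a minimal sketch (model and prompt are placeholders):

```python
# Minimal sketch of n-gram based (prompt lookup) assisted decoding; no assistant model is needed.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="auto")
inputs = tokenizer("The keys to a good report are structure, structure and", return_tensors="pt").to(model.device)

# Candidate tokens are drawn from n-grams already present in the prompt.
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=20, prompt_lookup_num_tokens=3)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```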
### DoLa Decoding
**D**ecoding by C**o**ntrasting **La**yers (DoLa) is a contrastive decoding strategy for improving the factuality and reducing the
hallucinations of LLMs, as described in the ICLR 2024 paper [DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language Models](https://arxiv.org/abs/2309.03883).
DoLa works by contrasting the logits obtained from the final layer with those from earlier layers, thus
amplifying the factual knowledge localized to particular parts of the transformer layers.
To activate DoLa decoding, do the following two steps when calling the `model.generate` function:
1. Set the `dola_layers` argument, which can be either a string or a list of integers.
- If set to a string, it can be one of `low` or `high`.
- If set to a list of integers, it should be a list of layer indices between 0 and the total number of layers in the model. The 0-th layer is the word embedding layer, the 1st layer is the first transformer layer, and so on.
2. Set `repetition_penalty=1.2`, which is suggested to reduce repetition in DoLa decoding.
See the following examples for DoLa decoding with the 32-layer LLaMA-7B model.
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
>>> model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", torch_dtype=torch.float16)
>>> device = 'cuda' if torch.cuda.is_available() else 'cpu'
>>> model.to(device)
>>> set_seed(42)
>>> text = "On what date was the Declaration of Independence officially signed?"
>>> inputs = tokenizer(text, return_tensors="pt").to(device)
# Vanilla greedy decoding
>>> vanilla_output = model.generate(**inputs, do_sample=False, max_new_tokens=50)
>>> tokenizer.batch_decode(vanilla_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
['\nThe Declaration of Independence was signed on July 4, 1776.\nWhat was the date of the signing of the Declaration of Independence?\nThe Declaration of Independence was signed on July 4,']
# DoLa decoding with contrasting higher part of layers (layers 16,18,...,30)
>>> dola_high_output = model.generate(**inputs, do_sample=False, max_new_tokens=50, dola_layers='high')
>>> tokenizer.batch_decode(dola_high_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
['\nJuly 4, 1776, when the Continental Congress voted to separate from Great Britain. The 56 delegates to the Continental Congress signed the Declaration on August 2, 1776.']
# DoLa decoding with contrasting specific layers (layers 28 and 30)
>>> dola_custom_output = model.generate(**inputs, do_sample=False, max_new_tokens=50, dola_layers=[28,30], repetition_penalty=1.2)
>>> tokenizer.batch_decode(dola_custom_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
['\nIt was officially signed on 2 August 1776, when 56 members of the Second Continental Congress, representing the original 13 American colonies, voted unanimously for the resolution for independence. The 2']
```
#### Understanding the `dola_layers` argument
`dola_layers` stands for the candidate layers in premature layer selection, as described in the DoLa paper. The selected premature layer will be contrasted with the final layer.
Setting `dola_layers` to `'low'` or `'high'` will select the lower or higher part of the layers to contrast, respectively.
- For `N`-layer models with `N <= 40` layers, the layers of `range(0, N // 2, 2)` and `range(N // 2, N, 2)` are used for `'low'` and `'high'` layers, respectively.
- For models with `N > 40` layers, the layers of `range(0, 20, 2)` and `range(N - 20, N, 2)` are used for `'low'` and `'high'` layers, respectively.
- If the model has tied word embeddings, we skip the word embedding (0-th) layer and start from the 2nd layer, as the early exit from word embeddings would become an identity function.
- Set `dola_layers` to a list of integers to contrast manually specified layers. For example, setting `dola_layers=[28,30]` will contrast the final layer (the 32nd layer) with the 28th and 30th layers.
The paper suggests contrasting `'high'` layers to improve short-answer tasks like TruthfulQA, and contrasting `'low'` layers to improve all the other long-answer reasoning tasks, such as GSM8K, StrategyQA, FACTOR, and VicunaQA. Applying DoLa to smaller models like GPT-2 is not recommended, as shown in Appendix N of the paper.
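To make the ranges above concrete, here is a small standalone sketch that mirrors these rules (the helper name is hypothetical, not a `transformers` API):

```python
# Hypothetical helper mirroring the `dola_layers` candidate-layer rules described above.
from typing import List, Union

def dola_candidate_layers(n_layers: int, dola_layers: Union[str, List[int]], tie_word_embeddings: bool = False) -> List[int]:
    # With tied word embeddings, skip the word embedding (0-th) layer and start from layer 2,
    # since early exit from the embeddings is just an identity function.
    start = 2 if tie_word_embeddings and n_layers > 2 else 0
    if dola_layers == "low":
        return list(range(start, n_layers // 2, 2)) if n_layers <= 40 else list(range(start, 20, 2))
    if dola_layers == "high":
        return list(range(n_layers // 2, n_layers, 2)) if n_layers <= 40 else list(range(n_layers - 20, n_layers, 2))
    # A list of integers selects candidate layers directly; indices beyond the final layer are dropped.
    return [i for i in dola_layers if i < n_layers]

print(dola_candidate_layers(32, "high"))  # [16, 18, 20, 22, 24, 26, 28, 30] -- as in the example above
print(dola_candidate_layers(32, "low"))   # [0, 2, 4, 6, 8, 10, 12, 14]
```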


@ -60,6 +60,7 @@ class GenerationMode(ExplicitEnum):
GREEDY_SEARCH = "greedy_search"
SAMPLE = "sample"
ASSISTED_GENERATION = "assisted_generation"
DOLA_GENERATION = "dola_generation"
# Beam methods
BEAM_SEARCH = "beam_search"
BEAM_SAMPLE = "beam_sample"
@ -81,6 +82,7 @@ class GenerationConfig(PushToHubMixin):
- *diverse beam-search decoding* if `num_beams>1` and `num_beam_groups>1`
- *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
- *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
- *dola decoding* if `dola_layers` is passed to `.generate()`
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
@ -305,6 +307,18 @@ class GenerationConfig(PushToHubMixin):
max_matching_ngram_size (`int`, *optional*, default to `None`):
The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided.
> Generation parameters exclusive to [DoLa decoding](https://arxiv.org/abs/2309.03883)
dola_layers (`str` or `List[int]`, *optional*):
The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively.
"low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the
layers up to the last 20 layers.
If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa.
The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
`'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
or [the paper](https://arxiv.org/abs/2309.03883) for more details.
> Parameters specific to the caching mechanism:
cache_implementation (`str`, *optional*, default to `None`):
@ -397,6 +411,9 @@ class GenerationConfig(PushToHubMixin):
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
# DoLa generation
self.dola_layers = kwargs.pop("dola_layers", None)
# Cache implementation
self.cache_implementation = kwargs.pop("cache_implementation", None)
self.cache_config = kwargs.pop("cache_config", None)
@ -495,6 +512,16 @@ class GenerationConfig(PushToHubMixin):
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
"is only supported with Greedy Search and Sample."
)
# DoLa generation may extend some generation modes
if self.dola_layers is not None:
if generation_mode in ("greedy_search", "sample"):
generation_mode = GenerationMode.DOLA_GENERATION
else:
raise ValueError(
"You've set `dola_layers`, which triggers DoLa generate. Currently, DoLa generate "
"is only supported with Greedy Search and Sample."
)
return generation_mode
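A hedged check of the mode resolution above, assuming this snippet lives in `GenerationConfig.get_generation_mode` (the method name is inferred, not shown in the diff):

```python
# Sketch: setting `dola_layers` flips greedy search / sampling into DOLA_GENERATION.
from transformers import GenerationConfig
from transformers.generation.configuration_utils import GenerationMode  # module inferred from this diff

config = GenerationConfig(dola_layers="low", repetition_penalty=1.2)
assert config.get_generation_mode() == GenerationMode.DOLA_GENERATION

# Combining `dola_layers` with beam search should raise, per the ValueError branch above.
beam_config = GenerationConfig(dola_layers="low", repetition_penalty=1.2, num_beams=4)
try:
    beam_config.get_generation_mode()
except ValueError as err:
    print(err)
```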
def validate(self, is_init=False):
@ -700,6 +727,17 @@ class GenerationConfig(PushToHubMixin):
"`generate()` (or a pipeline) directly."
)
# 6. if dola_layers is set, check if repetition_penalty is set to >= 1.2
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
dola_decoding_wrong_parameter_msg = (
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of {repetition_penalty}, "
"which could induce unwanted repetition. The recommended value for DoLa decoding is `repetition_penalty>=1.2`."
)
warnings.warn(
dola_decoding_wrong_parameter_msg.format(repetition_penalty=self.repetition_penalty),
UserWarning,
)
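And a short sketch of the new validation: leaving `repetition_penalty` at its default (1.0) while setting `dola_layers` surfaces the `UserWarning` defined above.

```python
# Sketch: the recommended `repetition_penalty>=1.2` for DoLa is enforced as a warning, not an error.
import warnings
from transformers import GenerationConfig

config = GenerationConfig(dola_layers="low")  # repetition_penalty defaults to 1.0 (< 1.2)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    config.validate()
assert any("repetition_penalty>=1.2" in str(w.message) for w in caught)
```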
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],


@ -20,9 +20,11 @@ import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from torch.nn import functional as F
from ..cache_utils import (
Cache,
@ -1908,6 +1910,28 @@ class GenerationMixin:
streamer=streamer,
**model_kwargs,
)
elif generation_mode == GenerationMode.DOLA_GENERATION:
if self._is_stateful:
# DoLa decoding was not designed for stateful models, and would require some changes
raise ValueError(
f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}"
)
prepared_logits_warper = (
self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
)
result = self._dola_decoding(
input_ids,
dola_layers=generation_config.dola_layers,
logits_processor=prepared_logits_processor,
logits_warper=prepared_logits_warper,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)
elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
if not model_kwargs["use_cache"]:
@ -2189,6 +2213,225 @@ class GenerationMixin:
)
return self._contrastive_search(*args, **kwargs)
def _dola_decoding(
self,
input_ids: torch.LongTensor,
dola_layers: Union[str, List[int]],
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
streamer: Optional["BaseStreamer"] = None,
logits_warper: Optional[LogitsProcessorList] = None,
**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
r"""
Generates sequences of token ids for models with a language modeling head using **dola decoding** and can be
used for decoder-only text models.
The method is based on the paper "DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language
Models" (https://arxiv.org/abs/2309.03883) in ICLR 2024.
Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
dola_layers (`Union[str, List[int]]`):
The candidate layers used in contrasting layers of DoLa. It can be either 1) 'low' or 'high', which
means the lower part or higher part of the model layers, respectively, or 2) a list of layer indices
to be used for candidate layers. The 0-th layer is the word embedding layer of the model.
logits_processor (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
stopping_criteria (`StoppingCriteriaList`, *optional*):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
logits_warper (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
to warp the prediction score distribution of the language modeling head applied before multinomial
sampling at each generation step.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
Return:
[`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`]
or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
[`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
"""
if self.config.is_encoder_decoder:
raise ValueError("DoLa decoding is only available for decoder-only models.")
# init values
pad_token_id = generation_config.pad_token_id
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
do_sample = generation_config.do_sample
if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
raise ValueError(
"`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
f"{logits_warper})."
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# keep track of which sequences are already finished
batch_size = input_ids.shape[0]
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
this_peer_finished = False
# prepare layers for DoLa decoding
final_layer = self.config.num_hidden_layers
# if the model has tied word embeddings, we skip the word embeddings (0-th) layer and start from the 2nd layer,
# as the early exit from word embeddings will become identity function
# if the model is really shallow (<=2 layers), we use the 1st layer if it's not the final layer and the 0-th
# layer otherwise. Notice that DoLa does not help shallow models much.
if not self.config.tie_word_embeddings:
start_layer = 0
elif final_layer > 2:
start_layer = 2
elif final_layer == 2:
start_layer = 1
else:
start_layer = 0
# For `N`-layer models with `N <= 40` layers, the layers of `range(0, N // 2, 2)` and `range(N // 2, N, 2)`
# are used for `'low'` and `'high'` layers, respectively.
# For models with `N > 40` layers, the layers of `range(0, 20, 2)` and `range(N - 20, N, 2)` are used for
# `'low'` and `'high'` layers, respectively.
if isinstance(dola_layers, str) and dola_layers == "low":
if start_layer == final_layer // 2:
candidate_premature_layers = [start_layer]
else:
candidate_premature_layers = (
list(range(start_layer, final_layer // 2, 2))
if final_layer <= 40
else list(range(start_layer, 20, 2))
)
elif isinstance(dola_layers, str) and dola_layers == "high":
candidate_premature_layers = (
list(range(final_layer // 2, final_layer, 2))
if final_layer <= 40
else list(range(final_layer - 20, final_layer, 2))
)
# Set the `dola_layers` to a list of integers for layer indices to contrast manually specified layers.
elif isinstance(dola_layers, list):
candidate_premature_layers = [i for i in dola_layers if i < final_layer]
else:
raise ValueError("dola_layers must be either 'low', 'high' or a list of integers.")
lm_head = self.get_output_embeddings()
if lm_head is None:
raise ValueError("DoLa is not supported for models that don't have output embeddings.")
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=True,
)
final_layer_next_token_logits = outputs.logits[:, -1, :].detach().clone()
final_logits = outputs.logits[:, -1, :]
candidate_premature_logits = {}
for candidate_premature_layer in candidate_premature_layers:
candidate_premature_logits[candidate_premature_layer] = lm_head(
outputs.hidden_states[candidate_premature_layer][:, -1, :]
)
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
next_token_logits = _dola_select_contrast(
candidate_premature_layers, candidate_premature_logits, final_logits
)
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
if do_sample: # sample
next_token_scores = logits_warper(input_ids, next_token_scores)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_logits:
raw_logits += (final_layer_next_token_logits,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
if do_sample: # sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else: # argmax
next_tokens = torch.argmax(next_token_scores, dim=-1)
# finished sentences should have their next token be a padding token
if has_eos_stopping_criteria:
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
# stop when each sentence is finished
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
if streamer is not None:
streamer.end()
if return_dict_in_generate:
return GenerateDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return input_ids
@torch.no_grad()
def _contrastive_search(
self,
@ -4197,3 +4440,75 @@ def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
# Return a new object of the inferred class with the concatenated attributes
return model_output_cls(**concatenated_data)
def _relative_top_filter(
scores: torch.FloatTensor,
baseline_scores: torch.FloatTensor,
relative_top: float = 0.1,
filter_value: float = -float("Inf"),
base_filter_value=-1e-3,
min_tokens_to_keep: int = 1,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""
Reference: https://github.com/XiangLi1999/ContrastiveDecoding/blob/170e9142e92159c1237d731e240f5eb14aabf428/transformers/src/transformers/generation_logits_process.py#L235
Apply filtering to only keep tokens with a probability above a certain threshold. The threshold is defined as `relative_top` * max probability in the distribution.
"""
scores_normalized = scores.log_softmax(dim=-1)
baseline_scores_normalized = baseline_scores.log_softmax(dim=-1)
sorted_logits, sorted_indices = torch.sort(scores_normalized, descending=True)
min_thresh = sorted_logits[..., min_tokens_to_keep - 1]
probs_max = torch.max(scores_normalized, dim=-1).values
probs_thresh = probs_max + np.log(relative_top)
probs_thresh = torch.min(min_thresh, probs_thresh)
probs_thresh = probs_thresh.unsqueeze(-1)
baseline_scores_normalized[scores_normalized < probs_thresh] = base_filter_value
scores_normalized[scores_normalized < probs_thresh] = filter_value
return scores_normalized, baseline_scores_normalized
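For a sense of what the filter does, a toy invocation (values arbitrary; the import path is assumed from the file this diff touches):

```python
# Toy check of `_relative_top_filter`: with relative_top=0.1, tokens whose log-probability falls
# more than log(0.1) below the per-row maximum are masked out.
import torch
from transformers.generation.utils import _relative_top_filter  # assumed module path

scores = torch.tensor([[4.0, 3.5, 0.0, -2.0]])
baseline_scores = torch.tensor([[1.0, 1.0, 1.0, 1.0]])
filtered_scores, filtered_baseline = _relative_top_filter(scores, baseline_scores, relative_top=0.1)
# Masked positions become -inf in `filtered_scores` and -1e-3 in `filtered_baseline`,
# so they can never win the contrastive argmax downstream.
print(filtered_scores)
print(filtered_baseline)
```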
def _dola_select_contrast(
candidate_premature_layers: List[int],
candidate_premature_logits: Dict[int, torch.FloatTensor],
final_logits: torch.FloatTensor,
) -> torch.FloatTensor:
if len(candidate_premature_layers) == 1:
base_logits = candidate_premature_logits[candidate_premature_layers[0]]
final_logits, base_logits = _relative_top_filter(final_logits, base_logits)
logits = final_logits - base_logits
return logits
# 1. Stacking all premature_layers into a new dimension
stacked_premature_layers = torch.stack([candidate_premature_logits[i] for i in candidate_premature_layers], dim=0)
# 2. Calculate the softmax values for mature_layer and all premature_layers
# shape: (batch_size, vocab_size)
softmax_mature_layer = F.softmax(final_logits, dim=-1)
# shape: (num_premature_layers, batch_size, vocab_size)
softmax_premature_layers = F.softmax(stacked_premature_layers, dim=-1)
# 3. Calculate the average distribution
# shape: (num_premature_layers, batch_size, vocab_size)
avg_dist = 0.5 * (softmax_mature_layer[None, :, :] + softmax_premature_layers)
# 4. Calculate log-softmax for the KL divergence
# shape: (batch_size, vocab_size)
log_softmax_mature_layer = F.log_softmax(final_logits, dim=-1)
# shape: (num_premature_layers, batch_size, vocab_size)
log_softmax_premature_layers = F.log_softmax(stacked_premature_layers, dim=-1)
# 5. Calculate the KL divergences and then the JS divergences
# shape: (num_premature_layers, batch_size)
kl1 = F.kl_div(log_softmax_mature_layer[None, :, :], avg_dist, reduction="none").mean(-1)
# shape: (num_premature_layers, batch_size)
kl2 = F.kl_div(log_softmax_premature_layers, avg_dist, reduction="none").mean(-1)
js_divs = 0.5 * (kl1 + kl2) # shape: (num_premature_layers, batch_size)
# 6. Reduce the batchmean
js_divs = js_divs.mean(-1) # shape: (num_premature_layers,)
premature_layer = candidate_premature_layers[int(js_divs.argmax().cpu().item())]
base_logits = candidate_premature_logits[premature_layer]
final_logits, base_logits = _relative_top_filter(final_logits, base_logits)
logits = final_logits - base_logits
return logits
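Likewise, a toy call illustrating the shape contract of the selection routine above (import path assumed as before):

```python
# Toy check of `_dola_select_contrast`: two candidate premature layers, batch of 1, vocabulary of 4.
import torch
from transformers.generation.utils import _dola_select_contrast  # assumed module path

final_logits = torch.tensor([[4.0, 2.0, 1.0, 0.5]])
candidate_premature_logits = {
    0: torch.tensor([[2.0, 2.0, 1.5, 1.0]]),
    2: torch.tensor([[3.5, 2.1, 1.2, 0.4]]),
}
# The candidate layer whose distribution diverges most from the final layer is selected,
# and its (relative-top filtered) logits are subtracted from the final-layer logits.
contrasted = _dola_select_contrast([0, 2], candidate_premature_logits, final_logits)
print(contrasted.shape)  # torch.Size([1, 4])
```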


@ -1264,6 +1264,55 @@ class GenerationTesterMixin:
for output in (output_greedy, output_prompt_lookup):
self._check_outputs(output, input_ids, model.config, use_cache=True)
def test_dola_decoding_sample(self):
# TODO (joao): investigate skips, try to reduce incompatibilities
for model_class in self.all_generative_model_classes:
if model_class._is_stateful:
self.skipTest(reason="Stateful models don't support DoLa decoding")
if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
self.skipTest("Skip Reformer as the lm_head input size is 2 * hidden size, adopted from Rev Nets.")
if any(model_name in model_class.__name__.lower() for model_name in ["marian", "mbart", "pegasus"]):
self.skipTest("DoLa is not supported for models that don't return layerwise hidden states")
# enable cache if the model supports it (some models, e.g. openai-gpt, xlnet, cpm, xlm, don't expose `use_cache`)
config, input_ids, attention_mask = self._get_input_ids_and_config()
# Some models don't support the cache and returning past_key_values
if not hasattr(config, "use_cache"):
config.use_cache = False
else:
config.use_cache = True
# Encoder-decoder models are not supported
if config.is_encoder_decoder:
self.skipTest("DoLa is not supported for encoder-decoder models")
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
if model.get_output_embeddings() is None:
self.skipTest("DoLa is not supported for models that don't have output embeddings")
# Sets dola generation arguments such that:
# a) no EOS is generated, to ensure generation doesn't break early
# b) there are at least two forward passes in the main model, to ensure the input preparation of
# the main model is correct
generation_kwargs = {
"eos_token_id": -1, # see a)
"max_new_tokens": 4, # see b)
"num_beams": 1,
"do_sample": True,
"output_scores": True,
"output_logits": True,
"output_hidden_states": True,
"output_attentions": self.has_attentions,
"return_dict_in_generate": True,
}
generation_kwargs.update({"dola_layers": "low"})
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs)
self._check_outputs(output_dola, input_ids, model.config, use_cache=config.use_cache)
def test_assisted_decoding_sample(self):
# In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
# match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with


@ -839,7 +839,6 @@ class GemmaIntegrationTest(unittest.TestCase):
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
@slow
@ -898,3 +897,24 @@ class GemmaIntegrationTest(unittest.TestCase):
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
def test_model_2b_bf16_dola(self):
model_id = "google/gemma-2b"
# ground truth text generated with dola_layers="low", repetition_penalty=1.2
EXPECTED_TEXTS = [
"Hello I am doing an experiment and need to get the mass of a block. The problem is, it has no scale",
"Hi today we have the review for a <strong>2016/2017</strong> season of",
]
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(
**inputs, max_new_tokens=20, do_sample=False, dola_layers="low", repetition_penalty=1.2
)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)


@ -703,6 +703,29 @@ class LlamaIntegrationTest(unittest.TestCase):
)
)
@slow
def test_model_7b_dola_generation(self):
# ground truth text generated with dola_layers="low", repetition_penalty=1.2
EXPECTED_TEXT_COMPLETION = (
"Simply put, the theory of relativity states that 1) time and space are relative, and 2) the laws of "
"physics are the same for all observers in uniform motion relative to one another.\n\nThe theory of "
"relativity was developed by Albert Einstein in the early 20th century, and it revolutionized our "
"understanding of space and time."
)
prompt = "Simply put, the theory of relativity states that "
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16
)
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# greedy generation outputs
generated_ids = model.generate(
**model_inputs, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low"
)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
@slow
@require_torch_gpu
@require_read_token


@ -555,6 +555,30 @@ class MistralIntegrationTest(unittest.TestCase):
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
@slow
def test_model_7b_dola_generation(self):
# ground truth text generated with dola_layers="low", repetition_penalty=1.2
EXPECTED_TEXT_COMPLETION = (
"""My favourite condiment is 100% ketchup. I love it on everything, and Im not ash"""
)
prompt = "My favourite condiment is "
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
model = MistralForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
# greedy generation outputs
generated_ids = model.generate(
input_ids, max_new_tokens=20, temperature=0, dola_layers="low", repetition_penalty=1.2
)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
del model
backend_empty_cache(torch_device)
gc.collect()
@require_bitsandbytes
@slow
@require_flash_attn