mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Generate: Add assisted generation (#22211)
* working mvp
* remove breakpoint
* fix commit
* standardize outputs
* tmp commit
* tests almost ready
* tmp commit
* skip a few models
* Add streaming; Docs and examples
* document limitations
* PR commits
* Amy PR comments
This commit is contained in:
parent
90247d3e01
commit
78cda46f17
@ -332,3 +332,30 @@ The groups are selected to ensure they are distinct enough compared to the other
This guide illustrates the main parameters that enable various decoding strategies. More advanced parameters exist for the
[`generate`] method, which gives you even further control over the [`generate`] method's behavior.
For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.mdx).

### Assisted Generation

Assisted generation is a modification of the decoding strategies above that uses an assistant model with the same
tokenizer (ideally a much smaller model) to speed up the decoding process. Currently, only assisted greedy search is
supported, and batched inputs are not supported.

<!-- TODO: add link to the blog post about assisted generation when it exists -->

To enable assisted generation, set the `assistant_model` argument with a model.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "Alice and Bob"
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```
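Under the hood, the assistant greedily proposes a short block of candidate tokens and the main model then checks that block, keeping the longest prefix that matches its own greedy choices plus one token of its own. The snippet below is only a rough sketch of that propose-and-verify loop, using hypothetical `small_next_token` / `large_next_token` callables as stand-ins for the two models; it is not the actual implementation in `generate`:

```python
# Illustrative sketch only: `small_next_token` and `large_next_token` are hypothetical
# stand-ins that map a list of token ids to the next (greedy) token id.
def assisted_greedy_sketch(prompt_ids, small_next_token, large_next_token, max_new_tokens=20, num_candidates=5):
    ids = list(prompt_ids)
    new_tokens = 0
    while new_tokens < max_new_tokens:
        # 1. the cheap assistant forecasts a block of candidate tokens, one by one
        candidates = []
        for _ in range(num_candidates):
            candidates.append(small_next_token(ids + candidates))

        # 2. keep candidates only while they match the main model's own greedy picks
        #    (in the real method this check comes from a single forward pass of the main model)
        n_matches = 0
        for i, candidate in enumerate(candidates):
            if large_next_token(ids + candidates[:i]) == candidate:
                n_matches += 1
            else:
                break

        # 3. accept the matching prefix, plus one token chosen by the main model itself
        ids += candidates[:n_matches]
        ids.append(large_next_token(ids))
        new_tokens += n_matches + 1
    return ids
```

The speedup comes from step 2: because the main model is causal, it can score every candidate position in a single forward pass, so it needs far fewer sequential passes than plain greedy search whenever the assistant's guesses are mostly correct.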
@ -73,9 +73,9 @@ from .stopping_criteria import (
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from .streamers import BaseStreamer


logger = logging.get_logger(__name__)
@ -1146,6 +1146,7 @@ class GenerationMixin:
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
@ -1196,10 +1197,14 @@ class GenerationMixin:
Whether to continue running the while loop until max_length. Unless overridden, this flag will be set to
`True` under DeepSpeed ZeRO Stage 3 multi-GPU environments to avoid hanging if one GPU finished
generating before other GPUs. Otherwise it'll be set to `False`.
assistant_model (`PreTrainedModel`, *optional*):
An assistant model that can be used to accelerate generation. The assistant model must have the exact
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
is much faster than running generation with the model you're calling generate from. As such, the
assistant model should be much smaller.
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.

kwargs:
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@ -1411,6 +1416,14 @@ class GenerationMixin:
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_assisted_greedy_gen_mode = False
if assistant_model is not None:
if not is_greedy_gen_mode:
raise ValueError(
"You've set `assistant_model`, which triggers assisted generation. Currently, assisted generation "
"is only supported with Greedy Search."
)
is_assisted_greedy_gen_mode = True

if generation_config.num_beam_groups > generation_config.num_beams:
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
@ -1449,11 +1462,47 @@ class GenerationMixin:
generation_config=generation_config, stopping_criteria=stopping_criteria
)
# 10. go into different generation modes
if is_assisted_greedy_gen_mode:
if generation_config.num_return_sequences > 1:
raise ValueError(
"num_return_sequences has to be 1 when doing assisted greedy search, "
f"but is {generation_config.num_return_sequences}."
)
if batch_size > 1:
raise ValueError("Assisted generation is only supported for batch_size = 1")
if not model_kwargs["use_cache"]:
raise ValueError("Assisted generation requires `use_cache=True`")

# 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
if assistant_model.config.is_encoder_decoder:
assistant_model_kwargs = copy.deepcopy(model_kwargs)
inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs(
inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs
)
assistant_model_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
inputs_tensor, assistant_model_kwargs, model_input_name
)
model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"]

# 12. run assisted greedy search
return self.assisted_greedy_search(
input_ids,
assistant_model=assistant_model,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)
if is_greedy_gen_mode:
if generation_config.num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
" greedy search."
"num_return_sequences has to be 1 when doing greedy search, "
f"but is {generation_config.num_return_sequences}."
)

# 11. run greedy search
@ -1473,9 +1522,11 @@ class GenerationMixin:
elif is_contrastive_search_gen_mode:
if generation_config.num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
" contrastive search."
"num_return_sequences has to be 1 when doing contrastive search, "
f"but is {generation_config.num_return_sequences}."
)
if not model_kwargs["use_cache"]:
raise ValueError("Contrastive search requires `use_cache=True`")

return self.contrastive_search(
input_ids,
@ -1745,7 +1796,7 @@ class GenerationMixin:
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
synced_gpus: bool = False,
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
@ -2112,7 +2163,7 @@ class GenerationMixin:
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
synced_gpus: bool = False,
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
) -> Union[GreedySearchOutput, torch.LongTensor]:
@ -2368,7 +2419,7 @@ class GenerationMixin:
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
synced_gpus: bool = False,
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
) -> Union[SampleOutput, torch.LongTensor]:
@ -2646,7 +2697,7 @@ class GenerationMixin:
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
synced_gpus: bool = False,
**model_kwargs,
) -> Union[BeamSearchOutput, torch.LongTensor]:
r"""
@ -2970,7 +3021,7 @@ class GenerationMixin:
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
synced_gpus: bool = False,
**model_kwargs,
) -> Union[BeamSampleOutput, torch.LongTensor]:
r"""
@ -3302,7 +3353,7 @@ class GenerationMixin:
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
synced_gpus: bool = False,
**model_kwargs,
):
r"""
@ -3994,6 +4045,468 @@ class GenerationMixin:
else:
return sequence_outputs["sequences"]

def assisted_greedy_search(
self,
input_ids: torch.LongTensor,
assistant_model: "PreTrainedModel",
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: bool = False,
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
):
r"""
Generates sequences of token ids for models with a language modeling head using **greedy decoding**, assisted
by a smaller model. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

<Tip warning={true}>

In most cases, you do not need to call [`~generation.GenerationMixin.assisted_greedy_search`] directly. Use
generate() instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies).

</Tip>

Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
assistant_model (`PreTrainedModel`, *optional*):
An assistant model that can be used to accelerate generation. The assistant model must have the exact
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
is much faster than running generation with the model you're calling generate from. As such, the
assistant model should be much smaller.
logits_processor (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
stopping_criteria (`StoppingCriteriaList`, *optional*):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.

Return:
[`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
[`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.

Examples:

```python
>>> from transformers import (
...     AutoTokenizer,
...     AutoModelForCausalLM,
...     LogitsProcessorList,
...     MinLengthLogitsProcessor,
...     StoppingCriteriaList,
...     MaxLengthCriteria,
... )

>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
>>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
>>> model.generation_config.pad_token_id = model.generation_config.eos_token_id
>>> input_prompt = "It might be possible to"
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
>>> # instantiate logits processors
>>> logits_processor = LogitsProcessorList(
...     [
...         MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
...     ]
... )
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
>>> outputs = model.assisted_greedy_search(
...     input_ids,
...     assistant_model=assistant_model,
...     logits_processor=logits_processor,
...     stopping_criteria=stopping_criteria,
... )
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
```"""
# NOTE: the code here is copy/paste from greedy search, except when clearly stated in the comments
# Assistant: initialize assistant-related variables
if not hasattr(assistant_model, "max_assistant_tokens"):
assistant_model.max_assistant_tokens = 5  # this value, which will be updated, persists across calls
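# `max_assistant_tokens` is the number of candidate tokens the assistant forecasts per iteration;
# the heuristic in step 5 below grows or shrinks it depending on how many candidates were accepted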
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if eos_token_id is not None and pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate
if return_dict_in_generate is not None
else self.generation_config.return_dict_in_generate
)

# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)

# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

this_peer_finished = False  # used by synced_gpus only
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
break

# Assistant: main logic start
cur_len = input_ids.shape[-1]
max_len = stopping_criteria[0].max_length

# 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
# need access to the assistant cache to secure strong speedups.
candidate_input_ids = input_ids
for _ in range(int(assistant_model.max_assistant_tokens)):
# 1.1. use the assistant model to obtain the next candidate logits
if "assistant_past_key_values" in model_kwargs:
prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2]
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
tmp_inputs = candidate_input_ids[:, -new_token_len:]
tmp_attn = torch.ones_like(candidate_input_ids)
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
if assistant_model.config.is_encoder_decoder:
assistant_model_outputs = assistant_model(
decoder_input_ids=tmp_inputs,
decoder_attention_mask=tmp_attn,
past_key_values=model_kwargs["assistant_past_key_values"],
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
)
else:
assistant_model_outputs = assistant_model(
tmp_inputs,
attention_mask=tmp_attn,
past_key_values=model_kwargs["assistant_past_key_values"],
)
else:
if assistant_model.config.is_encoder_decoder:
assistant_model_outputs = assistant_model(
decoder_input_ids=candidate_input_ids,
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
)
else:
assistant_model_outputs = assistant_model(candidate_input_ids)

# 1.2. greedily select the next candidate token
model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
if len(logits_processor) > 0:
assistant_model_outputs.logits[:, -1, :] = logits_processor(
candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
)
new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)

# 1.3. stop assistant generation on EOS
if eos_token_id_tensor is not None:
last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1)
last_assistant_token_is_eos = (
~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
)
if last_assistant_token_is_eos:
break
else:
last_assistant_token_is_eos = False

candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]

# 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
# `candidate_length + 1` relevant logits from this process (see step 7 on why the +1)
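# (the extra logit sits at the position right after the last candidate token, so even when every
# candidate is accepted the main model still contributes one token of its own in step 7)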
if "past_key_values" in model_kwargs:
og_model_attn = torch.ones_like(candidate_input_ids)
og_model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
if self.config.is_encoder_decoder:
outputs = self(
decoder_input_ids=og_model_input_ids,
decoder_attention_mask=og_model_attn,
past_key_values=model_kwargs["past_key_values"],
encoder_outputs=model_kwargs["encoder_outputs"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
else:
outputs = self(
og_model_input_ids,
attention_mask=og_model_attn,
past_key_values=model_kwargs["past_key_values"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
else:
if self.config.is_encoder_decoder:
outputs = self(
decoder_input_ids=candidate_input_ids,
encoder_outputs=model_kwargs["encoder_outputs"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
else:
outputs = self(
candidate_input_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

# 3. Obtain the argmax from the original model logits.
new_logits = outputs.logits[:, -candidate_length - 1 :]  # excludes the input prompt if present
if len(logits_processor) > 0:
for i in range(candidate_length):
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
max_logits = new_logits.argmax(dim=-1)[:, -candidate_length - 1 : -1]

# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
n_matches = ((~(candidate_new_tokens == max_logits)).cumsum(dim=-1) < 1).sum()
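# i.e. the length of the matching prefix: if the per-position agreement is [True, True, False, True],
# the mismatch cumsum is [0, 0, 1, 1] and n_matches = 2 (everything after the first mismatch is discarded)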
# 5. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
# cost of forecasting incorrect assistant tokens.
if n_matches == int(assistant_model.max_assistant_tokens):
assistant_model.max_assistant_tokens += 2.0
else:
assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
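# e.g. starting from the default of 5: if all 5 candidates were accepted, try 7 on the next iteration;
# if any candidate was rejected, back off by 1, never going below 1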
# 6. Update variables according to the number of matching assistant tokens.
# 6.1. Ensure we don't generate beyond max_len or an EOS token (remember: one token will be added below)
n_matches = min(n_matches, max_len - cur_len - 1)
if last_assistant_token_is_eos and n_matches == candidate_length:
n_matches -= 1
input_ids = candidate_input_ids[:, 0 : cur_len + n_matches]
new_cur_len = input_ids.shape[-1]
if streamer is not None:
streamer.put(candidate_input_ids[:, cur_len : cur_len + n_matches])

# 6.2. Discard past key values relative to unused assistant tokens
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cur_len)
model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
assistant_model, model_kwargs["assistant_past_key_values"], new_cur_len
)

# 6.3. Extract the logits for the next token
next_token_scores = new_logits[:, n_matches, :]

# 7. Use the set of logits after the last matching assistant token to obtain the next token. Note that,
# because of this step, assisted greedy search reduces to a normal greedy search if there is no match.
next_tokens = torch.argmax(next_token_scores, dim=-1)

# Assistant: main logic end; Compared to greedy search, the following (redundant) blocks were removed
# below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model cache
# update.

if synced_gpus and this_peer_finished:
continue  # don't waste resources running the code we don't need

# Store scores, attentions and hidden_states when required
# Assistant: modified to append one tuple element per token, as in the other generation methods.
if return_dict_in_generate:
if output_scores:
scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))

if "past_key_values" not in model_kwargs:
last_matching_idx = new_cur_len - 1
prompt_length = cur_len
else:
last_matching_idx = n_matches
prompt_length = 0

if output_attentions:
if self.config.is_encoder_decoder:
cross_attentions = _split_model_outputs(
cross_attentions, outputs.cross_attentions, prompt_length, last_matching_idx
)
decoder_attentions = _split_model_outputs(
decoder_attentions,
outputs.decoder_attentions,
prompt_length,
last_matching_idx,
is_decoder_attention=True,
)
else:
decoder_attentions = _split_model_outputs(
decoder_attentions,
outputs.attentions,
prompt_length,
last_matching_idx,
is_decoder_attention=True,
)
if output_hidden_states:
if self.config.is_encoder_decoder:
decoder_hidden_states = _split_model_outputs(
decoder_hidden_states, outputs.decoder_hidden_states, prompt_length, last_matching_idx
)
else:
decoder_hidden_states = _split_model_outputs(
decoder_hidden_states, outputs.hidden_states, prompt_length, last_matching_idx
)

# finished sentences should have their next token be a padding token
if eos_token_id is not None:
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)

# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)

# stop when each sentence is finished, or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
if not synced_gpus:
break
else:
this_peer_finished = True

if streamer is not None:
streamer.end()

if return_dict_in_generate:
if self.config.is_encoder_decoder:
return GreedySearchEncoderDecoderOutput(
sequences=input_ids,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
return GreedySearchDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
return input_ids


def _crop_past_key_values(model, past_key_values, maximum_length):
"""Crops the past key values up to a certain maximum length."""
new_past = []
if model.config.is_encoder_decoder:
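# each layer's cache here holds (self-attn key, self-attn value, cross-attn key, cross-attn value); only the
# self-attention tensors grow with the decoded sequence, so the cross-attention entries are kept unchanged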
for idx in range(len(past_key_values)):
new_past.append(
(
past_key_values[idx][0][:, :, :maximum_length, :],
past_key_values[idx][1][:, :, :maximum_length, :],
past_key_values[idx][2],
past_key_values[idx][3],
)
)
past_key_values = tuple(new_past)
elif "bloom" in model.__class__.__name__.lower():  # bloom is special
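# assumption: Bloom stores keys as (batch_size * num_heads, head_dim, seq_len) and values as
# (batch_size * num_heads, seq_len, head_dim), so the sequence axis to crop differs between the two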
for idx in range(len(past_key_values)):
new_past.append(
(
past_key_values[idx][0][:, :, :maximum_length],
past_key_values[idx][1][:, :maximum_length, :],
)
)
past_key_values = tuple(new_past)
else:
for idx in range(len(past_key_values)):
new_past.append(
(
past_key_values[idx][0][:, :, :maximum_length, :],
past_key_values[idx][1][:, :, :maximum_length, :],
)
)
past_key_values = tuple(new_past)
return past_key_values


def _split_model_outputs(outputs, new_outputs, prompt_length, last_matching_idx, is_decoder_attention=False):
"""
Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
where each member corresponds to a single generated token.
"""
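# when `is_decoder_attention=True`, each per-token slice is also cut on its last (key/value) axis to the
# causal length seen at that step; hidden states and cross-attentions keep their full last dimension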
# Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the
# prompt.
if prompt_length > 0:
new_tuple = ()
for layer in new_outputs:
last_dim_size = prompt_length if is_decoder_attention else layer.shape[-1]
new_tuple += (layer[..., :prompt_length, :last_dim_size],)
outputs += (new_tuple,)

for i in range(prompt_length, last_matching_idx + 1):
new_tuple = ()
for layer in new_outputs:
last_dim_size = i + 1 if is_decoder_attention else layer.shape[-1]
new_tuple += (layer[..., i : i + 1, :last_dim_size],)
outputs += (new_tuple,)
return outputs


def top_k_top_p_filtering(
logits: torch.FloatTensor,
@ -79,14 +79,13 @@ class GenerationTesterMixin:
all_generative_model_classes = ()
input_name = "input_ids"

def _get_input_ids_and_config(self):
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict[self.input_name]

# cut to half length & take max batch_size 3
max_batch_size = 2
sequence_length = input_ids.shape[-1] // 2
input_ids = input_ids[:max_batch_size, :sequence_length]
input_ids = input_ids[:batch_size, :sequence_length]

# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
@ -99,7 +98,7 @@ class GenerationTesterMixin:
if "transfoxl" in config.__class__.__name__.lower():
attention_mask = None
else:
attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:max_batch_size, :sequence_length]
attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :sequence_length]

return config, input_ids, attention_mask, max_length
@ -1458,6 +1457,66 @@ class GenerationTesterMixin:
for output in (output_contrastive, output_generate):
self._check_outputs(output, input_ids, model.config, use_cache=True)

def test_assisted_greedy_search_matches_greedy_search(self):
# This test ensures that assisted generation does not introduce output changes over greedy search.
# It breaks the pattern in the tests above, for multiple reasons:
# - assisted_greedy_search, contrary to the other methods, can't be called on its own (e.g. needs to
# prepare the assistant encoder outputs in the main generate body);
# - assisted_greedy_search does not support `use_cache = False`
# - assisted_greedy_search does not support `batch_size > 1`

for model_class in self.all_generative_model_classes:
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
return
# may fix in the future: the following models fail to pass this test, and need model-specific fixes
if any(
model_name in model_class.__name__.lower()
for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text"]
):
return

# enable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)

# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
return

config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_greedy = model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=1,
do_sample=False,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
# Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will
# be correct
output_assisted = model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=1,
do_sample=False,
assistant_model=model,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)

self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())

for output in (output_greedy, output_assisted):
self._check_outputs(output, input_ids, model.config, use_cache=True)

def test_generate_with_head_masking(self):
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
@ -280,7 +280,7 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
# overwrite from GenerationTesterMixin to solve problem
# with conflicting random seeds
def _get_input_ids_and_config(self):
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.attention_type = "original_full"
@ -288,10 +288,9 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
attention_mask = torch.ones_like(input_ids, dtype=torch.long)

# cut to half length & take max batch_size 3
max_batch_size = 2
sequence_length = input_ids.shape[-1] // 2
input_ids = input_ids[:max_batch_size, :sequence_length]
attention_mask = attention_mask[:max_batch_size, :sequence_length]
input_ids = input_ids[:batch_size, :sequence_length]
attention_mask = attention_mask[:batch_size, :sequence_length]

# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
@ -303,7 +303,7 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
input_ids = input_ids[:max_batch_size, :, :]

# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
max_length = 4
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id
@ -359,16 +359,15 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)

def _get_input_ids_and_config(self):
def _get_input_ids_and_config(self, batch_size=3):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict[self.input_name]

# cut to half length & take max batch_size 3
max_batch_size = 3
input_ids = input_ids[:max_batch_size, :, :]
# cut to half length & take max batch_size=batch_size
input_ids = input_ids[:batch_size, :, :]

# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
max_length = 4
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id