Generate: Add assisted generation (#22211)

* working mvp

* remove breakpoint

* fix commit

* standardize outputs

* tmp commit

* tests almost ready

* tmp commit

* skip a few models

* Add streaming; Docs and examples

* document limitations

* PR commits

* Amy PR comments
This commit is contained in:
Joao Gante 2023-04-18 17:36:56 +01:00 committed by GitHub
parent 90247d3e01
commit 78cda46f17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 623 additions and 26 deletions

View File

@ -332,3 +332,30 @@ The groups are selected to ensure they are distinct enough compared to the other
This guide illustrates the main parameters that enable various decoding strategies. More advanced parameters exist for the
[`generate`] method, which gives you even further control over the [`generate`] method's behavior.
For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.mdx).
### Assisted Generation
Assisted generation is a modification of the decoding strategies above that uses an assistant model with the same
tokenizer (ideally a much smaller model) to speed up the decoding process. Currently only assisted greedy search is
supported, and doesn't support batched inputs.
<!-- TODO: add link to the blog post about assisted generation when it exists -->
To enable assisted generation, set the `assistant_model` argument with a model.
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> prompt = "Alice and Bob"
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```

View File

@ -73,9 +73,9 @@ from .stopping_criteria import (
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from .streamers import BaseStreamer
logger = logging.get_logger(__name__)
@ -1146,6 +1146,7 @@ class GenerationMixin:
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
@ -1196,10 +1197,14 @@ class GenerationMixin:
Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
`True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
generating before other GPUs. Otherwise it'll be set to `False`.
assistant_model (`PreTrainedModel`, *optional*):
An assistant model that can be used to accelerate generation. The assistant model must have the exact
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
is much faster than running generation with the model you're calling generate from. As such, the
assistant model should be much smaller.
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
kwargs:
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@ -1411,6 +1416,14 @@ class GenerationMixin:
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_assisted_greedy_gen_mode = False
if assistant_model is not None:
if not is_greedy_gen_mode:
raise ValueError(
"You've set `assistant_model`, which triggers assisted generation. Currently, assisted generation "
"is only supported with Greedy Search."
)
is_assisted_greedy_gen_mode = True
if generation_config.num_beam_groups > generation_config.num_beams:
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
@ -1449,11 +1462,47 @@ class GenerationMixin:
generation_config=generation_config, stopping_criteria=stopping_criteria
)
# 10. go into different generation modes
if is_assisted_greedy_gen_mode:
if generation_config.num_return_sequences > 1:
raise ValueError(
"num_return_sequences has to be 1 when doing assisted greedy search, "
f"but is {generation_config.num_return_sequences}."
)
if batch_size > 1:
raise ValueError("Assisted generation is only supported for batch_size = 1")
if not model_kwargs["use_cache"]:
raise ValueError("Assisted generation requires `use_cache=True`")
# 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
if assistant_model.config.is_encoder_decoder:
assistant_model_kwargs = copy.deepcopy(model_kwargs)
inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs(
inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs
)
assistant_model_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
inputs_tensor, assistant_model_kwargs, model_input_name
)
model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"]
# 12. run assisted greedy search
return self.assisted_greedy_search(
input_ids,
assistant_model=assistant_model,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)
if is_greedy_gen_mode:
if generation_config.num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
" greedy search."
"num_return_sequences has to be 1 when doing greedy search, "
f"but is {generation_config.num_return_sequences}."
)
# 11. run greedy search
@ -1473,9 +1522,11 @@ class GenerationMixin:
elif is_contrastive_search_gen_mode:
if generation_config.num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
" contrastive search."
"num_return_sequences has to be 1 when doing contrastive search, "
f"but is {generation_config.num_return_sequences}."
)
if not model_kwargs["use_cache"]:
raise ValueError("Contrastive search requires `use_cache=True`")
return self.contrastive_search(
input_ids,
@ -1745,7 +1796,7 @@ class GenerationMixin:
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
synced_gpus: bool = False,
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
@ -2112,7 +2163,7 @@ class GenerationMixin:
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
synced_gpus: bool = False,
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
) -> Union[GreedySearchOutput, torch.LongTensor]:
@ -2368,7 +2419,7 @@ class GenerationMixin:
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
synced_gpus: bool = False,
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
) -> Union[SampleOutput, torch.LongTensor]:
@ -2646,7 +2697,7 @@ class GenerationMixin:
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
synced_gpus: bool = False,
**model_kwargs,
) -> Union[BeamSearchOutput, torch.LongTensor]:
r"""
@ -2970,7 +3021,7 @@ class GenerationMixin:
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
synced_gpus: bool = False,
**model_kwargs,
) -> Union[BeamSampleOutput, torch.LongTensor]:
r"""
@ -3302,7 +3353,7 @@ class GenerationMixin:
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
synced_gpus: bool = False,
**model_kwargs,
):
r"""
@ -3994,6 +4045,468 @@ class GenerationMixin:
else:
return sequence_outputs["sequences"]
def assisted_greedy_search(
self,
input_ids: torch.LongTensor,
assistant_model: "PreTrainedModel",
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: bool = False,
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
):
r"""
Generates sequences of token ids for models with a language modeling head using **greedy decoding**, assisted
by a smaller model. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
<Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.assisted_greedy_search`] directly. Use
generate() instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies).
</Tip>
Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
assistant_model (`PreTrainedModel`, *optional*):
An assistant model that can be used to accelerate generation. The assistant model must have the exact
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
is much faster than running generation with the model you're calling generate from. As such, the
assistant model should be much smaller.
logits_processor (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
stopping_criteria (`StoppingCriteriaList`, *optional*):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
Return:
[`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
[`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
Examples:
```python
>>> from transformers import (
... AutoTokenizer,
... AutoModelForCausalLM,
... LogitsProcessorList,
... MinLengthLogitsProcessor,
... StoppingCriteriaList,
... MaxLengthCriteria,
... )
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
>>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
>>> model.generation_config.pad_token_id = model.generation_config.eos_token_id
>>> input_prompt = "It might be possible to"
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
>>> # instantiate logits processors
>>> logits_processor = LogitsProcessorList(
... [
... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
... ]
... )
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
>>> outputs = model.assisted_greedy_search(
... input_ids,
... assistant_model=assistant_model,
... logits_processor=logits_processor,
... stopping_criteria=stopping_criteria,
... )
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
```"""
# NOTE: the code here is copy/paste from greedy search, except when clearly stated in the comments
# Assistant: initialize assistant-related variables
if not hasattr(assistant_model, "max_assistant_tokens"):
assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if eos_token_id is not None and pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate
if return_dict_in_generate is not None
else self.generation_config.return_dict_in_generate
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
this_peer_finished = False # used by synced_gpus only
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
break
# Assistant: main logic start
cur_len = input_ids.shape[-1]
max_len = stopping_criteria[0].max_length
# 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
# need access to the assistant cache to secure strong speedups.
candidate_input_ids = input_ids
for _ in range(int(assistant_model.max_assistant_tokens)):
# 1.1. use the assistant model to obtain the next candidate logits
if "assistant_past_key_values" in model_kwargs:
prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2]
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
tmp_inputs = candidate_input_ids[:, -new_token_len:]
tmp_attn = torch.ones_like(candidate_input_ids)
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
if assistant_model.config.is_encoder_decoder:
assistant_model_outputs = assistant_model(
decoder_input_ids=tmp_inputs,
decoder_attention_mask=tmp_attn,
past_key_values=model_kwargs["assistant_past_key_values"],
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
)
else:
assistant_model_outputs = assistant_model(
tmp_inputs,
attention_mask=tmp_attn,
past_key_values=model_kwargs["assistant_past_key_values"],
)
else:
if assistant_model.config.is_encoder_decoder:
assistant_model_outputs = assistant_model(
decoder_input_ids=candidate_input_ids,
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
)
else:
assistant_model_outputs = assistant_model(candidate_input_ids)
# 1.2. greedily select the next candidate token
model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
if len(logits_processor) > 0:
assistant_model_outputs.logits[:, -1, :] = logits_processor(
candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
)
new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
# 1.3. stop assistant generation on EOS
if eos_token_id_tensor is not None:
last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1)
last_assistant_token_is_eos = (
~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
)
if last_assistant_token_is_eos:
break
else:
last_assistant_token_is_eos = False
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
# 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
# `candidate_length + 1` relevant logits from this process (see step 7 on why the +1)
if "past_key_values" in model_kwargs:
og_model_attn = torch.ones_like(candidate_input_ids)
og_model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
if self.config.is_encoder_decoder:
outputs = self(
decoder_input_ids=og_model_input_ids,
decoder_attention_mask=og_model_attn,
past_key_values=model_kwargs["past_key_values"],
encoder_outputs=model_kwargs["encoder_outputs"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
else:
outputs = self(
og_model_input_ids,
attention_mask=og_model_attn,
past_key_values=model_kwargs["past_key_values"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
else:
if self.config.is_encoder_decoder:
outputs = self(
decoder_input_ids=candidate_input_ids,
encoder_outputs=model_kwargs["encoder_outputs"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
else:
outputs = self(
candidate_input_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# 3. Obtain the argmax from the original model logits.
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
if len(logits_processor) > 0:
for i in range(candidate_length):
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
max_logits = new_logits.argmax(dim=-1)[:, -candidate_length - 1 : -1]
# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
n_matches = ((~(candidate_new_tokens == max_logits)).cumsum(dim=-1) < 1).sum()
# 5. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
# cost of forecasting incorrect assistant tokens.
if n_matches == int(assistant_model.max_assistant_tokens):
assistant_model.max_assistant_tokens += 2.0
else:
assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
# 6. Update variables according to the number of matching assistant tokens.
# 6.1. Ensure we don't generate beyond max_len or an EOS token (remember: one token will be added below)
n_matches = min(n_matches, max_len - cur_len - 1)
if last_assistant_token_is_eos and n_matches == candidate_length:
n_matches -= 1
input_ids = candidate_input_ids[:, 0 : cur_len + n_matches]
new_cur_len = input_ids.shape[-1]
if streamer is not None:
streamer.put(candidate_input_ids[:, cur_len : cur_len + n_matches])
# 6.2. Discard past key values relative to unused assistant tokens
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cur_len)
model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
assistant_model, model_kwargs["assistant_past_key_values"], new_cur_len
)
# 6.3. Extract the logits for the next token
next_token_scores = new_logits[:, n_matches, :]
# 7. Use the set of logits after the last matching assistant token to obtain the next token. Note that,
# because of this step, assisted greedy search reduces to a normal greedy search if there is no match.
next_tokens = torch.argmax(next_token_scores, dim=-1)
# Assistant: main logic end; Compared to greedy search, the following (redundant) blocks were removed
# below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model cache
# update.
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
# Store scores, attentions and hidden_states when required
# Assistant: modified to append one tuple element per token, as in the other generation methods.
if return_dict_in_generate:
if output_scores:
scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
if "past_key_values" not in model_kwargs:
last_matching_idx = new_cur_len - 1
prompt_length = cur_len
else:
last_matching_idx = n_matches
prompt_length = 0
if output_attentions:
if self.config.is_encoder_decoder:
cross_attentions = _split_model_outputs(
cross_attentions, outputs.cross_attentions, prompt_length, last_matching_idx
)
decoder_attentions = _split_model_outputs(
decoder_attentions,
outputs.decoder_attentions,
prompt_length,
last_matching_idx,
is_decoder_attention=True,
)
else:
decoder_attentions = _split_model_outputs(
decoder_attentions,
outputs.attentions,
prompt_length,
last_matching_idx,
is_decoder_attention=True,
)
if output_hidden_states:
if self.config.is_encoder_decoder:
decoder_hidden_states = _split_model_outputs(
decoder_hidden_states, outputs.decoder_hidden_states, prompt_length, last_matching_idx
)
else:
decoder_hidden_states = _split_model_outputs(
decoder_hidden_states, outputs.hidden_states, prompt_length, last_matching_idx
)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)
# stop when each sentence is finished, or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
if not synced_gpus:
break
else:
this_peer_finished = True
if streamer is not None:
streamer.end()
if return_dict_in_generate:
if self.config.is_encoder_decoder:
return GreedySearchEncoderDecoderOutput(
sequences=input_ids,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
return GreedySearchDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
return input_ids
def _crop_past_key_values(model, past_key_values, maximum_length):
"""Crops the past key values up to a certain maximum length."""
new_past = []
if model.config.is_encoder_decoder:
for idx in range(len(past_key_values)):
new_past.append(
(
past_key_values[idx][0][:, :, :maximum_length, :],
past_key_values[idx][1][:, :, :maximum_length, :],
past_key_values[idx][2],
past_key_values[idx][3],
)
)
past_key_values = tuple(new_past)
elif "bloom" in model.__class__.__name__.lower(): # bloom is special
for idx in range(len(past_key_values)):
new_past.append(
(
past_key_values[idx][0][:, :, :maximum_length],
past_key_values[idx][1][:, :maximum_length, :],
)
)
past_key_values = tuple(new_past)
else:
for idx in range(len(past_key_values)):
new_past.append(
(
past_key_values[idx][0][:, :, :maximum_length, :],
past_key_values[idx][1][:, :, :maximum_length, :],
)
)
past_key_values = tuple(new_past)
return past_key_values
def _split_model_outputs(outputs, new_outputs, prompt_length, last_matching_idx, is_decoder_attention=False):
"""
Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
where each member corresponds to a single generated token.
"""
# Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the
# prompt.
if prompt_length > 0:
new_tuple = ()
for layer in new_outputs:
last_dim_size = prompt_length if is_decoder_attention else layer.shape[-1]
new_tuple += (layer[..., :prompt_length, :last_dim_size],)
outputs += (new_tuple,)
for i in range(prompt_length, last_matching_idx + 1):
new_tuple = ()
for layer in new_outputs:
last_dim_size = i + 1 if is_decoder_attention else layer.shape[-1]
new_tuple += (layer[..., i : i + 1, :last_dim_size],)
outputs += (new_tuple,)
return outputs
def top_k_top_p_filtering(
logits: torch.FloatTensor,

View File

@ -79,14 +79,13 @@ class GenerationTesterMixin:
all_generative_model_classes = ()
input_name = "input_ids"
def _get_input_ids_and_config(self):
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict[self.input_name]
# cut to half length & take max batch_size 3
max_batch_size = 2
sequence_length = input_ids.shape[-1] // 2
input_ids = input_ids[:max_batch_size, :sequence_length]
input_ids = input_ids[:batch_size, :sequence_length]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
@ -99,7 +98,7 @@ class GenerationTesterMixin:
if "transfoxl" in config.__class__.__name__.lower():
attention_mask = None
else:
attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:max_batch_size, :sequence_length]
attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :sequence_length]
return config, input_ids, attention_mask, max_length
@ -1458,6 +1457,66 @@ class GenerationTesterMixin:
for output in (output_contrastive, output_generate):
self._check_outputs(output, input_ids, model.config, use_cache=True)
def test_assisted_greedy_search_matches_greedy_search(self):
# This test ensures that the assisted generation does not introduce output changes over greedy search.
# It breaks the pattern in the tests above, for multiple reasons:
# - assisted_greedy_search, contrarily to the other methods, can't be called on its own (e.g. needs to
# prepare the assistant encoder outputs in the main generate body);
# - assisted_greedy_search does not support `use_cache = False`
# - assisted_greedy_search does not support `batch_size > 1`
for model_class in self.all_generative_model_classes:
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
return
# may fix in the future: the following models fail to pass this test, and need model-specific fixes
if any(
model_name in model_class.__name__.lower()
for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text"]
):
return
# enable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
return
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_greedy = model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=1,
do_sample=False,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
# Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will
# be correct
output_assisted = model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=1,
do_sample=False,
assistant_model=model,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
for output in (output_greedy, output_assisted):
self._check_outputs(output, input_ids, model.config, use_cache=True)
def test_generate_with_head_masking(self):
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]

View File

@ -280,7 +280,7 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
# overwrite from GenerationTesterMixin to solve problem
# with conflicting random seeds
def _get_input_ids_and_config(self):
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.attention_type = "original_full"
@ -288,10 +288,9 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
# cut to half length & take max batch_size 3
max_batch_size = 2
sequence_length = input_ids.shape[-1] // 2
input_ids = input_ids[:max_batch_size, :sequence_length]
attention_mask = attention_mask[:max_batch_size, :sequence_length]
input_ids = input_ids[:batch_size, :sequence_length]
attention_mask = attention_mask[:batch_size, :sequence_length]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3

View File

@ -303,7 +303,7 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
input_ids = input_ids[:max_batch_size, :, :]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
max_length = 4
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id

View File

@ -359,16 +359,15 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
def _get_input_ids_and_config(self):
def _get_input_ids_and_config(self, batch_size=3):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict[self.input_name]
# cut to half length & take max batch_size 3
max_batch_size = 3
input_ids = input_ids[:max_batch_size, :, :]
# cut to half length & take max batch_size=batch_size
input_ids = input_ids[:batch_size, :, :]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
max_length = 4
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id