mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix assisted decoding assistant model inputs (#27503)
* fix assisted decoding attention_cat * fix attention_mask for assisted decoding * fix attention_mask len * fix attn len * Use a more clean way to prepare assistant models inputs * fix param meaning * fix param name * fix assistant model inputs * update token type ids * fix assistant kwargs copy * add encoder-decoder tests of assisted decoding * check if assistant kwargs contains updated keys * revert test * fix whisper tests * fix assistant kwargs * revert whisper test * delete _extend funcs
This commit is contained in:
parent
307cf3a2ab
commit
1d7f406e19
@ -1391,43 +1391,6 @@ class GenerationMixin:
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
def _extend_attention_mask(self, model_kwargs: Dict[str, Any], new_mask_length: int) -> Dict[str, Any]:
|
||||
if self.config.is_encoder_decoder:
|
||||
key = "decoder_attention_mask"
|
||||
else:
|
||||
key = "attention_mask"
|
||||
|
||||
if key not in model_kwargs:
|
||||
return model_kwargs
|
||||
|
||||
mask = model_kwargs[key]
|
||||
mask_extension_length = new_mask_length - mask.shape[1]
|
||||
|
||||
if mask_extension_length < 0:
|
||||
raise ValueError("Cannot extend attention mask to a length less than it already is")
|
||||
|
||||
model_kwargs[key] = torch.cat(
|
||||
[mask, mask.new_ones((mask.shape[0], mask_extension_length))],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
return model_kwargs
|
||||
|
||||
def _extend_token_type_ids(self, model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
|
||||
if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
|
||||
return model_kwargs
|
||||
|
||||
token_type_ids = model_kwargs["token_type_ids"]
|
||||
final_token_type = token_type_ids[:, -1].unsqueeze(-1)
|
||||
extension_length = new_length - token_type_ids.shape[1]
|
||||
token_type_copies = final_token_type.repeat(1, extension_length)
|
||||
model_kwargs["token_type_ids"] = torch.cat(
|
||||
[model_kwargs["token_type_ids"], token_type_copies],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
return model_kwargs
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
@ -4505,11 +4468,6 @@ class GenerationMixin:
|
||||
else:
|
||||
num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
|
||||
|
||||
# check if assistant model accepts encoder_outputs
|
||||
assistant_accepts_encoder_outputs = "encoder_outputs" in set(
|
||||
inspect.signature(assistant_model.forward).parameters.keys()
|
||||
)
|
||||
|
||||
# init values
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||
@ -4547,20 +4505,32 @@ class GenerationMixin:
|
||||
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
||||
)
|
||||
|
||||
# prepare assistant model's keys of inputs
|
||||
assistant_kwargs = copy.copy(model_kwargs)
|
||||
if assistant_model.config.is_encoder_decoder:
|
||||
# both are encoder-decoder
|
||||
input_ids_key = "decoder_input_ids"
|
||||
attention_key = "decoder_attention_mask"
|
||||
assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs")
|
||||
elif "assistant_encoder_outputs" in assistant_kwargs:
|
||||
# special case for encoder-decoder with decoder-only assistant (like DistilWhisper)
|
||||
input_ids_key = "input_ids"
|
||||
attention_key = "attention_mask"
|
||||
assistant_kwargs["attention_mask"] = assistant_kwargs.get(
|
||||
"decoder_attention_mask",
|
||||
torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
|
||||
)
|
||||
assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs")
|
||||
else:
|
||||
# both are decoder-only
|
||||
input_ids_key = "input_ids"
|
||||
attention_key = "attention_mask"
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
||||
|
||||
# other auxiliary variables
|
||||
max_len = stopping_criteria[0].max_length
|
||||
assistant_kv_indexing = (
|
||||
1
|
||||
if "bloom" in assistant_model.__class__.__name__.lower()
|
||||
or (
|
||||
assistant_model.config.architectures is not None
|
||||
and "bloom" in assistant_model.config.architectures[0].lower()
|
||||
)
|
||||
else 0
|
||||
)
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
while True:
|
||||
@ -4582,44 +4552,21 @@ class GenerationMixin:
|
||||
# need access to the assistant cache to secure strong speedups.
|
||||
candidate_input_ids = input_ids
|
||||
for _ in range(int(num_assistant_tokens)):
|
||||
# 1.1. use the assistant model to obtain the next candidate logits
|
||||
if "assistant_past_key_values" in model_kwargs:
|
||||
prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
|
||||
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
|
||||
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
|
||||
assist_inputs = candidate_input_ids[:, -new_token_len:]
|
||||
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
|
||||
if assistant_model.config.is_encoder_decoder:
|
||||
assistant_model_outputs = assistant_model(
|
||||
decoder_input_ids=assist_inputs,
|
||||
past_key_values=model_kwargs["assistant_past_key_values"],
|
||||
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
||||
)
|
||||
else:
|
||||
encoder_kwargs = {}
|
||||
# 1.1 prepare assistant model inputs
|
||||
assistant_inputs = assistant_model.prepare_inputs_for_generation(
|
||||
candidate_input_ids,
|
||||
**assistant_kwargs,
|
||||
)
|
||||
|
||||
if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs:
|
||||
encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
|
||||
# 1.2. check if the input ids length is correct
|
||||
has_past_key_values = assistant_inputs.get("past_key_values", None) is not None
|
||||
if has_past_key_values and assistant_inputs[input_ids_key].shape[-1] not in (1, 2):
|
||||
raise ValueError("The length of the input ids in assistant inputs should be 1 or 2")
|
||||
|
||||
assistant_model_outputs = assistant_model(
|
||||
assist_inputs, past_key_values=model_kwargs["assistant_past_key_values"], **encoder_kwargs
|
||||
)
|
||||
else:
|
||||
if assistant_model.config.is_encoder_decoder:
|
||||
assistant_model_outputs = assistant_model(
|
||||
decoder_input_ids=candidate_input_ids,
|
||||
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
||||
)
|
||||
else:
|
||||
encoder_kwargs = {}
|
||||
# 1.3. use the assistant model to obtain the next candidate logits
|
||||
assistant_model_outputs = assistant_model(**assistant_inputs)
|
||||
|
||||
if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs:
|
||||
encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
|
||||
|
||||
assistant_model_outputs = assistant_model(candidate_input_ids, **encoder_kwargs)
|
||||
|
||||
# 1.2. greedily select the next candidate token
|
||||
model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
|
||||
# 1.4. greedily select the next candidate token
|
||||
if len(logits_processor) > 0:
|
||||
assistant_model_outputs.logits[:, -1, :] = logits_processor(
|
||||
candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
|
||||
@ -4627,7 +4574,13 @@ class GenerationMixin:
|
||||
new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
|
||||
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
|
||||
|
||||
# 1.3. stop assistant generation on EOS
|
||||
# 1.5. update assistant model inputs
|
||||
if assistant_kwargs.get(attention_key, None) is not None:
|
||||
mask = assistant_kwargs[attention_key]
|
||||
assistant_kwargs[attention_key] = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1)
|
||||
assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values
|
||||
|
||||
# 1.6. stop assistant generation on EOS
|
||||
if eos_token_id_tensor is not None:
|
||||
last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1)
|
||||
last_assistant_token_is_eos = (
|
||||
@ -4646,8 +4599,10 @@ class GenerationMixin:
|
||||
|
||||
# 2.1. Prepare the model inputs
|
||||
candidate_kwargs = copy.copy(model_kwargs)
|
||||
candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1])
|
||||
candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
|
||||
candidate_kwargs = _prepare_attention_mask(
|
||||
candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
|
||||
)
|
||||
candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
|
||||
|
||||
model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
|
||||
|
||||
@ -4699,8 +4654,8 @@ class GenerationMixin:
|
||||
# 5.3. Discard past key values relative to unused assistant tokens
|
||||
new_cache_size = new_cur_len - 1
|
||||
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
|
||||
model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
|
||||
assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1
|
||||
assistant_kwargs["past_key_values"] = _crop_past_key_values(
|
||||
assistant_model, assistant_kwargs["past_key_values"], new_cache_size - 1
|
||||
) # the assistant does not have the token after the last match, hence the -1
|
||||
|
||||
# 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
|
||||
@ -4761,6 +4716,12 @@ class GenerationMixin:
|
||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||
)
|
||||
|
||||
# Update assistant_kwargs for the assistant's next round of generations
|
||||
assistant_kwargs = _prepare_attention_mask(
|
||||
assistant_kwargs, new_cur_len, assistant_model.config.is_encoder_decoder
|
||||
)
|
||||
assistant_kwargs = _prepare_token_type_ids(assistant_kwargs, new_cur_len)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if eos_token_id_tensor is not None:
|
||||
unfinished_sequences = unfinished_sequences.mul(
|
||||
@ -4938,3 +4899,37 @@ def _ranking_fast(
|
||||
contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K]
|
||||
_, selected_idx = contrastive_score.max(dim=-1) # [B]
|
||||
return selected_idx
|
||||
|
||||
|
||||
def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]:
|
||||
"""Expands or crops the model's mask for decoding purposes, to the defined length"""
|
||||
|
||||
mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
|
||||
if mask_key not in model_kwargs:
|
||||
return model_kwargs
|
||||
|
||||
mask = model_kwargs[mask_key]
|
||||
mask_length_diff = new_length - mask.shape[1]
|
||||
|
||||
if mask_length_diff < 0:
|
||||
model_kwargs[mask_key] = mask[:, :mask_length_diff]
|
||||
elif mask_length_diff > 0:
|
||||
model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
|
||||
return model_kwargs
|
||||
|
||||
|
||||
def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
|
||||
"""Expands or crops the model's token_type_ids for decoding purposes, to the defined length"""
|
||||
if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
|
||||
return model_kwargs
|
||||
|
||||
token_type_ids = model_kwargs["token_type_ids"]
|
||||
final_token_type = token_type_ids[:, -1].unsqueeze(-1)
|
||||
type_length_diff = new_length - token_type_ids.shape[1]
|
||||
|
||||
if type_length_diff < 0:
|
||||
token_type_ids = token_type_ids[:, :type_length_diff]
|
||||
elif type_length_diff > 0:
|
||||
token_type_copies = final_token_type.repeat(1, type_length_diff)
|
||||
model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
|
||||
return model_kwargs
|
||||
|
@ -348,10 +348,6 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
self.assertIsNotNone(model(**input_dict)["encoder_router_logits"][1])
|
||||
self.assertIsNotNone(model(**input_dict)["decoder_router_logits"][0])
|
||||
|
||||
@unittest.skip("Test does not fail individually but fails on the CI @ArthurZucker looking into it")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
|
@ -726,10 +726,6 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
|
||||
def test_disk_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test does not fail individually but fails on the CI @ArthurZucker looking into it")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
|
||||
class SwitchTransformersEncoderOnlyModelTester:
|
||||
def __init__(
|
||||
|
@ -1036,10 +1036,6 @@ class T5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
|
||||
|
||||
@unittest.skip("Test does not fail individually but fails on the CI @ArthurZucker looking into it")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
|
||||
def use_task_specific_params(model, task):
|
||||
model.config.update(model.config.task_specific_params[task])
|
||||
|
Loading…
Reference in New Issue
Block a user