mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Generate: starcoder 🤜 🤛 assisted generation (#23182)
* starcoder has joined the chat

* indexing that works for all
This commit is contained in:
parent dbc12269ed
commit bbfb9fc22b
src/transformers/generation/utils.py

@@ -4221,6 +4221,9 @@ class GenerationMixin:
         # keep track of which sequences are already finished
         unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
 
+        # other auxiliary variables
+        max_len = stopping_criteria[0].max_length
+
         this_peer_finished = False  # used by synced_gpus only
         while True:
             if synced_gpus:
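Note (not part of the diff): hoisting `max_len` out of the loop assumes the first entry in `stopping_criteria` is the `MaxLengthCriteria` that `generate()` builds from `max_length`/`max_new_tokens`. A minimal sketch of that assumption:

    from transformers import MaxLengthCriteria, StoppingCriteriaList

    # Sketch only: assisted decoding reads the length limit once, before the
    # loop, instead of re-deriving it on every iteration.
    stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)])
    max_len = stopping_criteria[0].max_length
    assert max_len == 64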
@@ -4235,7 +4238,7 @@ class GenerationMixin:
 
             # Assistant: main logic start
             cur_len = input_ids.shape[-1]
-            max_len = stopping_criteria[0].max_length
+            assistant_kv_indexing = 0 if "bloom" not in assistant_model.__class__.__name__.lower() else 1
 
             # 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
             # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
@@ -4244,7 +4247,7 @@ class GenerationMixin:
             for _ in range(int(assistant_model.max_assistant_tokens)):
                 # 1.1. use the assistant model to obtain the next candidate logits
                 if "assistant_past_key_values" in model_kwargs:
-                    prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2]
+                    prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
                     # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
                     new_token_len = candidate_input_ids.shape[1] - prev_seq_len
                     assist_inputs = candidate_input_ids[:, -new_token_len:]
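Note (not part of the diff): the old expression read `shape[2]` of the key tensor, which breaks when a layer's cache is a single fused tensor (gptbigcode) and, once dim -2 is used instead, would break for bloom's transposed key. Reading dim -2 of the tensor picked by `assistant_kv_indexing` works for all three layouts; a minimal sketch, with cache shapes as the model implementations define them:

    import torch

    batch, heads, seq_len, head_dim = 2, 4, 10, 8

    # gpt2-style: each layer caches a (key, value) pair; seq_len is dim -2 of both
    gpt2_layer = (
        torch.zeros(batch, heads, seq_len, head_dim),   # key
        torch.zeros(batch, heads, seq_len, head_dim),   # value
    )
    assert gpt2_layer[0].shape[-2] == seq_len           # assistant_kv_indexing = 0

    # bloom: the key is stored transposed, so only the value has seq_len at dim -2
    bloom_layer = (
        torch.zeros(batch * heads, head_dim, seq_len),  # key
        torch.zeros(batch * heads, seq_len, head_dim),  # value
    )
    assert bloom_layer[1].shape[-2] == seq_len          # assistant_kv_indexing = 1

    # gptbigcode (multi-query): one fused KV tensor per layer; `[0]` drops the
    # batch dim, and seq_len still sits at dim -2 (a hard-coded `.shape[2]` would fail)
    bigcode_layer = torch.zeros(batch, seq_len, 2 * head_dim)
    assert bigcode_layer[0].shape[-2] == seq_len        # assistant_kv_indexing = 0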
@@ -4505,6 +4508,13 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
                 )
             )
         past_key_values = tuple(new_past)
+    elif "gptbigcode" in model.__class__.__name__.lower():  # gptbigcode is too
+        if model.config.multi_query:
+            for idx in range(len(past_key_values)):
+                past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]
+        else:
+            for idx in range(len(past_key_values)):
+                past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
     else:
         for idx in range(len(past_key_values)):
             new_past.append(
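Note (not part of the diff): gptbigcode needs its own crop branch because each layer caches key and value fused into one tensor, and multi-query attention drops the per-head axis, which moves the sequence dimension. A minimal sketch of the two layouts being sliced above:

    import torch

    batch, heads, seq_len, head_dim, maximum_length = 2, 4, 10, 8, 6

    # multi_query=True: one shared KV head, fused tensor [batch, seq_len, 2 * head_dim]
    mqa_cache = [torch.zeros(batch, seq_len, 2 * head_dim)]
    mqa_cache[0] = mqa_cache[0][:, :maximum_length, :]
    assert mqa_cache[0].shape == (batch, maximum_length, 2 * head_dim)

    # multi_query=False: per-head axis first, [batch, heads, seq_len, 2 * head_dim]
    mha_cache = [torch.zeros(batch, heads, seq_len, 2 * head_dim)]
    mha_cache[0] = mha_cache[0][:, :, :maximum_length, :]
    assert mha_cache[0].shape == (batch, heads, maximum_length, 2 * head_dim)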
tests/generation/test_utils.py

@@ -1473,7 +1473,7 @@ class GenerationTesterMixin:
         # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
         if any(
             model_name in model_class.__name__.lower()
-            for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
+            for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
         ):
             return
 
@@ -1529,7 +1529,7 @@ class GenerationTesterMixin:
         # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
         if any(
             model_name in model_class.__name__.lower()
-            for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
+            for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
         ):
             return
 
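Note (not part of the diff): both test hunks drop "gptbigcode" from the assisted-decoding skip list, so GPTBigCode models now run these tests. A minimal sketch of the skip check, using the model class name only as a stand-in:

    blocked = ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]

    class GPTBigCodeForCausalLM:  # stand-in for the real model class
        pass

    model_class = GPTBigCodeForCausalLM
    skipped = any(name in model_class.__name__.lower() for name in blocked)
    assert not skipped  # gptbigcode no longer matches the blocked list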