Generate: starcoder 🤜 🤛 assisted generation (#23182)

* starcoder has joined the chat

* indexing that works for all
This commit is contained in:
Joao Gante 2023-05-08 10:45:40 +01:00 committed by GitHub
parent dbc12269ed
commit bbfb9fc22b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 4 deletions

View File

@ -4221,6 +4221,9 @@ class GenerationMixin:
# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
# other auxiliary variables
max_len = stopping_criteria[0].max_length
this_peer_finished = False # used by synced_gpus only
while True:
if synced_gpus:
@ -4235,7 +4238,7 @@ class GenerationMixin:
# Assistant: main logic start
cur_len = input_ids.shape[-1]
max_len = stopping_criteria[0].max_length
assistant_kv_indexing = 0 if "bloom" not in assistant_model.__class__.__name__.lower() else 1
# 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
@ -4244,7 +4247,7 @@ class GenerationMixin:
for _ in range(int(assistant_model.max_assistant_tokens)):
# 1.1. use the assistant model to obtain the next candidate logits
if "assistant_past_key_values" in model_kwargs:
prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2]
prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
assist_inputs = candidate_input_ids[:, -new_token_len:]
@ -4505,6 +4508,13 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
)
)
past_key_values = tuple(new_past)
elif "gptbigcode" in model.__class__.__name__.lower(): # gptbigcode is too
if model.config.multi_query:
for idx in range(len(past_key_values)):
past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]
else:
for idx in range(len(past_key_values)):
past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
else:
for idx in range(len(past_key_values)):
new_past.append(

View File

@ -1473,7 +1473,7 @@ class GenerationTesterMixin:
# may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
if any(
model_name in model_class.__name__.lower()
for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
):
return
@ -1529,7 +1529,7 @@ class GenerationTesterMixin:
# may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
if any(
model_name in model_class.__name__.lower()
for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
):
return