Generate: detect special architectures when loaded from PEFT (#24198)

This commit is contained in:
Joao Gante 2023-06-12 16:06:20 +01:00 committed by GitHub
parent 97527898da
commit 60b69f7de2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4232,6 +4232,15 @@ class GenerationMixin:
# other auxiliary variables
max_len = stopping_criteria[0].max_length
assistant_kv_indexing = (
1
if "bloom" in assistant_model.__class__.__name__.lower()
or (
assistant_model.config.architectures is not None
and "bloom" in assistant_model.config.architectures[0].lower()
)
else 0
)
this_peer_finished = False # used by synced_gpus only
while True:
@ -4247,7 +4256,6 @@ class GenerationMixin:
# Assistant: main logic start
cur_len = input_ids.shape[-1]
assistant_kv_indexing = 0 if "bloom" not in assistant_model.__class__.__name__.lower() else 1
# 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
@ -4512,7 +4520,10 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
)
)
past_key_values = tuple(new_past)
elif "bloom" in model.__class__.__name__.lower(): # bloom is special
# bloom is special
elif "bloom" in model.__class__.__name__.lower() or (
model.config.architectures is not None and "bloom" in model.config.architectures[0].lower()
):
for idx in range(len(past_key_values)):
new_past.append(
(
@ -4521,7 +4532,10 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
)
)
past_key_values = tuple(new_past)
elif "gptbigcode" in model.__class__.__name__.lower(): # gptbigcode is too
# gptbigcode is too
elif "gptbigcode" in model.__class__.__name__.lower() or (
model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
):
if model.config.multi_query:
for idx in range(len(past_key_values)):
past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]