mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Generate: detect special architectures when loaded from PEFT (#24198)
This commit is contained in:
parent
97527898da
commit
60b69f7de2
@ -4232,6 +4232,15 @@ class GenerationMixin:
|
||||
|
||||
# other auxiliary variables
|
||||
max_len = stopping_criteria[0].max_length
|
||||
assistant_kv_indexing = (
|
||||
1
|
||||
if "bloom" in assistant_model.__class__.__name__.lower()
|
||||
or (
|
||||
assistant_model.config.architectures is not None
|
||||
and "bloom" in assistant_model.config.architectures[0].lower()
|
||||
)
|
||||
else 0
|
||||
)
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
while True:
|
||||
@ -4247,7 +4256,6 @@ class GenerationMixin:
|
||||
|
||||
# Assistant: main logic start
|
||||
cur_len = input_ids.shape[-1]
|
||||
assistant_kv_indexing = 0 if "bloom" not in assistant_model.__class__.__name__.lower() else 1
|
||||
|
||||
# 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
|
||||
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
|
||||
@ -4512,7 +4520,10 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
|
||||
)
|
||||
)
|
||||
past_key_values = tuple(new_past)
|
||||
elif "bloom" in model.__class__.__name__.lower(): # bloom is special
|
||||
# bloom is special
|
||||
elif "bloom" in model.__class__.__name__.lower() or (
|
||||
model.config.architectures is not None and "bloom" in model.config.architectures[0].lower()
|
||||
):
|
||||
for idx in range(len(past_key_values)):
|
||||
new_past.append(
|
||||
(
|
||||
@ -4521,7 +4532,10 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
|
||||
)
|
||||
)
|
||||
past_key_values = tuple(new_past)
|
||||
elif "gptbigcode" in model.__class__.__name__.lower(): # gptbigcode is too
|
||||
# gptbigcode is too
|
||||
elif "gptbigcode" in model.__class__.__name__.lower() or (
|
||||
model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
|
||||
):
|
||||
if model.config.multi_query:
|
||||
for idx in range(len(past_key_values)):
|
||||
past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]
|
||||
|
Loading…
Reference in New Issue
Block a user