Fix kwargs handling in generate_with_fallback (#29225)

* Fix generate_with_fallback **kwargs

* Change pop to get

* Delete keys from kwargs to prevent overriding generation_config

* Revert to passing kwargs by reference, but make a (shallow) copy

* dict -> copy.copy

* Add test_whisper_longform_multi_batch_beam
Ondřej Cífka 2024-04-03 17:51:03 +02:00 committed by GitHub
parent 851f253f4d
commit bcd42c4af9
2 changed files with 63 additions and 2 deletions
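
The underlying bug: generate_with_fallback receives the caller's kwargs dict by reference and is invoked repeatedly, once per audio chunk and once per fallback temperature, so popping "num_beams" out of that dict on the first call silently dropped beam search for every later call. A minimal sketch of the failure mode, independent of the transformers code (names are illustrative):

# Sketch (not the transformers code): the same kwargs dict is reused
# for every chunk, so pop() strips "num_beams" after the first call.
def generate_chunks_buggy(chunks, kwargs):
    return [(chunk, kwargs.pop("num_beams", 1)) for chunk in chunks]

print(generate_chunks_buggy(["chunk0", "chunk1"], {"num_beams": 5}))
# [('chunk0', 5), ('chunk1', 1)]  <- beam search lost after the first chunk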


@@ -755,6 +755,8 @@ class WhisperGenerationMixin:
         do_condition_on_prev_tokens,
         kwargs,
     ):
+        kwargs = copy.copy(kwargs)
+
         # 6.6 Batch generate current chunk
         seek_sequence_list = [None for _ in range(cur_bsz)]
         seek_outputs_list = [None for _ in range(cur_bsz)]
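
The added line rebinds kwargs to a shallow copy, so the function can delete keys freely without mutating the caller's dict, while the values themselves (for example, large tensors) stay shared rather than duplicated; that is why copy.copy is used rather than copy.deepcopy. A small sketch of the semantics (values are illustrative):

import copy

caller_kwargs = {"num_beams": 5, "attention_mask": [1, 1, 1]}
local_kwargs = copy.copy(caller_kwargs)      # new dict, same value objects

local_kwargs.pop("num_beams")                # key removal stays local
print("num_beams" in caller_kwargs)          # True: caller's dict is intact
print(local_kwargs["attention_mask"] is caller_kwargs["attention_mask"])
# True: the values (e.g. tensors) are shared, not deep-copied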
@@ -769,8 +771,12 @@ class WhisperGenerationMixin:
             generation_config.do_sample = temperature is not None and temperature > 0.0
             generation_config.temperature = temperature if generation_config.do_sample else 1.0
-            generation_config.num_beams = kwargs.pop("num_beams", 1) if not generation_config.do_sample else 1
+            generation_config.num_beams = kwargs.get("num_beams", 1) if not generation_config.do_sample else 1
+            generate_kwargs = copy.copy(kwargs)
+            for key in ["do_sample", "temperature", "num_beams"]:
+                if key in generate_kwargs:
+                    del generate_kwargs[key]
             seek_outputs = super().generate(
                 segment_input,
                 generation_config,
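
do_sample, temperature, and num_beams are written onto generation_config for each fallback attempt, so they must not also be forwarded through the keyword arguments: downstream generate() treats explicit generation kwargs as overrides of the generation config, which would undo the per-attempt settings. A hypothetical sketch of that precedence (simplified stand-in classes, not the transformers API):

# Sketch of why the per-attempt keys are stripped before forwarding.
class Config:
    num_beams = 5

def generate(config, **kwargs):
    # kwargs win over the config, mirroring how generation kwargs
    # override GenerationConfig attributes downstream.
    return kwargs.get("num_beams", config.num_beams)

config = Config()
config.num_beams = 1                     # forced for a sampling attempt
print(generate(config, num_beams=5))     # 5 -> a leftover kwarg overrides it
print(generate(config))                  # 1 -> stripping the key preserves it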
@@ -779,7 +785,7 @@ class WhisperGenerationMixin:
                 prefix_allowed_tokens_fn,
                 synced_gpus,
                 decoder_input_ids=decoder_input_ids,
-                **kwargs,
+                **generate_kwargs,
             )
             # post-process sequence tokens and outputs to be in list form
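
The new test, test_whisper_longform_multi_batch_beam (in the suppressed test-file diff noted below), covers the user-facing case this fixes: long-form Whisper transcription with beam search now keeps num_beams on every chunk and every temperature-fallback retry. A hedged usage sketch; the model choice and parameter values are illustrative, and long_audio is a placeholder for more than 30 seconds of 16 kHz mono audio:

from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Passing the full, untruncated audio makes generate() take the
# long-form (chunked) path that generate_with_fallback implements.
inputs = processor(
    long_audio,
    sampling_rate=16_000,
    return_tensors="pt",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
)
predicted_ids = model.generate(
    **inputs,
    num_beams=2,                                  # kept across chunks after this fix
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),   # fallback schedule
    logprob_threshold=-1.0,
    compression_ratio_threshold=1.35,
    return_timestamps=True,
)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True))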

File diff suppressed because one or more lines are too long