Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-02 11:11:05 +06:00
Fix kwargs handling in generate_with_fallback (#29225)
* Fix generate_with_fallback **kwargs
* Change pop to get
* Delete keys from kwargs to prevent overriding generation_config
* Revert to passing kwargs by reference, but make a (shallow) copy
* dict -> copy.copy
* Add test_whisper_longform_multi_batch_beam
This commit is contained in: parent 851f253f4d, commit bcd42c4af9
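Why the copy matters: generate_with_fallback retries each audio chunk at progressively higher temperatures, so any mutation of the caller's kwargs dict leaks into later attempts and later chunks. With pop, num_beams is consumed on the first attempt and every subsequent attempt silently falls back to a single beam. The sketch below is a simplified stand-in for that loop (the helper names and the temperature tuple are illustrative, not taken from the transformers code):

```python
import copy

def run_fallback_broken(kwargs, temperatures):
    """Mutates the caller's dict: num_beams survives only the first attempt."""
    beams_seen = []
    for _ in temperatures:
        beams_seen.append(kwargs.pop("num_beams", 1))  # pop() removes the key for good
    return beams_seen

def run_fallback_fixed(kwargs, temperatures):
    """Shallow-copies kwargs and reads with get(), so every attempt sees num_beams."""
    kwargs = copy.copy(kwargs)                          # local copy; caller's dict untouched
    beams_seen = []
    for _ in temperatures:
        beams_seen.append(kwargs.get("num_beams", 1))   # get() leaves the key in place
    return beams_seen

user_kwargs = {"num_beams": 5}
print(run_fallback_broken(dict(user_kwargs), (0.0, 0.2, 0.4)))  # [5, 1, 1]
print(run_fallback_fixed(user_kwargs, (0.0, 0.2, 0.4)))         # [5, 5, 5]
print(user_kwargs)                                              # still {'num_beams': 5}
```

Deleting "do_sample", "temperature", and "num_beams" from the per-call generate_kwargs applies the same logic in the other direction: those values are already set on generation_config, so forwarding them again to super().generate() would override that config.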
@@ -755,6 +755,8 @@ class WhisperGenerationMixin:
         do_condition_on_prev_tokens,
         kwargs,
     ):
+        kwargs = copy.copy(kwargs)
+
         # 6.6 Batch generate current chunk
         seek_sequence_list = [None for _ in range(cur_bsz)]
         seek_outputs_list = [None for _ in range(cur_bsz)]
@@ -769,8 +771,12 @@ class WhisperGenerationMixin:
             generation_config.do_sample = temperature is not None and temperature > 0.0
 
             generation_config.temperature = temperature if generation_config.do_sample else 1.0
-            generation_config.num_beams = kwargs.pop("num_beams", 1) if not generation_config.do_sample else 1
+            generation_config.num_beams = kwargs.get("num_beams", 1) if not generation_config.do_sample else 1
 
+            generate_kwargs = copy.copy(kwargs)
+            for key in ["do_sample", "temperature", "num_beams"]:
+                if key in generate_kwargs:
+                    del generate_kwargs[key]
             seek_outputs = super().generate(
                 segment_input,
                 generation_config,
@@ -779,7 +785,7 @@ class WhisperGenerationMixin:
                 prefix_allowed_tokens_fn,
                 synced_gpus,
                 decoder_input_ids=decoder_input_ids,
-                **kwargs,
+                **generate_kwargs,
             )
 
             # post-process sequence tokens and outputs to be in list form
File diff suppressed because one or more lines are too long
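As a rough illustration of what the new test_whisper_longform_multi_batch_beam exercises, the sketch below runs long-form Whisper generation with beam search; after this fix, num_beams reaches every per-chunk generate() call instead of only the first. The checkpoint name and the synthetic 60-second waveform are placeholder assumptions, not values from the test:

```python
# Minimal sketch, assuming the openai/whisper-tiny checkpoint and a synthetic
# 60 s input; the real test uses its own fixtures and assertions.
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# 60 s of 16 kHz audio is longer than one 30 s chunk, so generation takes the
# long-form path that calls generate_with_fallback per chunk.
waveform = np.zeros(16_000 * 60, dtype=np.float32)
inputs = processor(
    waveform,
    sampling_rate=16_000,
    return_tensors="pt",
    truncation=False,             # keep the full long-form input
    padding="longest",
    return_attention_mask=True,
)

predicted_ids = model.generate(
    inputs.input_features,
    attention_mask=inputs.attention_mask,
    num_beams=2,                  # now honoured for every chunk, not just the first
    condition_on_prev_tokens=True,
    return_timestamps=True,       # long-form generation needs timestamps
)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True))
```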