[whisper] pass attention_mask to generate_with_fallback() (#33145)

pass attention_mask to generate_with_fallback
This commit is contained in:
benniekiss 2024-08-28 03:53:58 -04:00 committed by GitHub
parent 3bfd3e4803
commit e0b87b0f40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -686,6 +686,7 @@ class WhisperGenerationMixin:
do_condition_on_prev_tokens=do_condition_on_prev_tokens,
is_shortform=is_shortform,
batch_size=batch_size,
attention_mask=attention_mask,
kwargs=kwargs,
)
@ -790,6 +791,7 @@ class WhisperGenerationMixin:
do_condition_on_prev_tokens,
is_shortform,
batch_size,
attention_mask,
kwargs,
):
kwargs = copy.copy(kwargs)
@ -837,6 +839,7 @@ class WhisperGenerationMixin:
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
**generate_kwargs,
)