Fix whisper kwargs and generation config (#30018)
* clean-up whisper kwargs
* failing test
Parent: 9b5a6450d4
Commit: 76fa17c166
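The hunks below consistently read user options (`num_beams`, `temperature`, `max_new_tokens`, `max_length`, `forced_decoder_ids`, special token ids) from the `GenerationConfig` instead of ad-hoc `kwargs` lookups. A minimal caller-side sketch of the behaviour being fixed; the checkpoint, the silent one-second input, and the option values are illustrative only and not part of the diff:

```python
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# One second of silence stands in for real audio.
input_features = processor(
    torch.zeros(16000).numpy(), sampling_rate=16000, return_tensors="pt"
).input_features

# Options passed as plain keyword arguments should behave the same as options
# set on model.generation_config; the patched helpers below read the latter.
pred_ids = model.generate(input_features, num_beams=2, max_new_tokens=32)
print(processor.batch_decode(pred_ids, skip_special_tokens=True))
```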
```diff
@@ -511,7 +511,6 @@ class WhisperGenerationMixin:
         self._set_language_and_task(
             language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
         )
-        self._set_token_ids(generation_config=generation_config, config=self.config, kwargs=kwargs)
         self._set_num_frames(
             return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
         )
```
```diff
@@ -546,13 +545,13 @@ class WhisperGenerationMixin:
             logits_processor=logits_processor,
             begin_index=begin_index,  # begin index is index of first generated decoder token
             is_shortform=is_shortform,
-            num_beams=kwargs.get("num_beams", 1),
+            num_beams=generation_config.num_beams,
         )

         # 5. If we're in shortform mode, simple generate the whole input at once and return the output
         if is_shortform:
             if temperature is not None:
-                kwargs["temperature"] = temperature
+                generation_config.temperature = temperature

             decoder_input_ids = kwargs.pop("decoder_input_ids", None)
             if decoder_input_ids is None:
```
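Reading `generation_config.num_beams` (and writing `generation_config.temperature`) relies on user kwargs being folded into the generation config earlier in `generate()`, a part of the change not shown in these hunks. A small sketch of that merge using the public `GenerationConfig.update` method; the concrete values are made up:

```python
from transformers import GenerationConfig

generation_config = GenerationConfig(num_beams=1, temperature=1.0)
user_kwargs = {"num_beams": 4}

# update() copies matching keys onto the config and returns the leftovers,
# so a user-passed num_beams is visible to helpers that read the config.
leftover = generation_config.update(**user_kwargs)
assert generation_config.num_beams == 4 and leftover == {}
```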
```diff
@@ -564,8 +563,8 @@ class WhisperGenerationMixin:
                     [prompt_ids[None].repeat(decoder_input_ids.shape[0], 1), decoder_input_ids], dim=-1
                 )

-                if kwargs.get("max_new_tokens", 0) + decoder_input_ids.shape[-1] > self.config.max_target_positions:
-                    max_new_tokens = kwargs.get("max_new_tokens", 0)
+                max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0
+                if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions:
                     raise ValueError(
                         f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
                         f"is {max_new_tokens}. Thus, the combined length of "
```
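The reordered check keeps the same bound: prompt tokens plus new tokens may not exceed the decoder's `max_target_positions` (448 for the released Whisper checkpoints). A worked sketch with illustrative numbers:

```python
max_target_positions = 448  # config.max_target_positions for Whisper
decoder_input_ids_len = 60  # prompt_ids plus special start tokens (illustrative)
max_new_tokens = 400        # from generation_config.max_new_tokens, 0 if unset

# 400 + 60 = 460 > 448, so the patched code would raise a ValueError here.
if max_new_tokens + decoder_input_ids_len > max_target_positions:
    print("combined length exceeds max_target_positions -> ValueError")
```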
```diff
@@ -666,11 +665,10 @@ class WhisperGenerationMixin:
            )

            # 6.6 set max new tokens or max length
-            kwargs = self._set_max_new_tokens_and_length(
+            self._set_max_new_tokens_and_length(
                config=self.config,
                decoder_input_ids=decoder_input_ids,
                generation_config=generation_config,
-                kwargs=kwargs,
            )

            # 6.7 Set current `begin_index` for all logit processors
```
```diff
@@ -770,9 +768,9 @@ class WhisperGenerationMixin:

            for fallback_idx, temperature in enumerate(temperatures):
                generation_config.do_sample = temperature is not None and temperature > 0.0

                generation_config.temperature = temperature if generation_config.do_sample else 1.0
-                generation_config.num_beams = kwargs.get("num_beams", 1) if not generation_config.do_sample else 1
+                if generation_config.do_sample:
+                    generation_config.num_beams = 1

                generate_kwargs = copy.copy(kwargs)
                for key in ["do_sample", "temperature", "num_beams"]:
```
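In the temperature-fallback loop, beam search only applies to the non-sampling passes, so the patch stops re-reading `num_beams` from `kwargs` and simply pins it to 1 once sampling kicks in. A sketch of the resulting schedule, using the usual Whisper fallback temperatures and a made-up user beam count:

```python
temperatures = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)  # typical fallback schedule
user_num_beams = 4                             # illustrative value from generation_config

for temperature in temperatures:
    do_sample = temperature is not None and temperature > 0.0
    effective_temperature = temperature if do_sample else 1.0
    num_beams = 1 if do_sample else user_num_beams
    print(f"T={temperature}: do_sample={do_sample}, temperature={effective_temperature}, num_beams={num_beams}")
```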
```diff
@@ -1095,11 +1093,8 @@ class WhisperGenerationMixin:
         task = getattr(generation_config, "task", None)
         language = getattr(generation_config, "language", None)

-        if kwargs.get("forced_decoder_ids", None) is not None:
-            forced_decoder_ids = kwargs["forced_decoder_ids"]
-        elif hasattr(generation_config, "forced_decoder_ids") and generation_config.forced_decoder_ids is not None:
-            forced_decoder_ids = generation_config.forced_decoder_ids
+        forced_decoder_ids = generation_config.forced_decoder_ids
+        if forced_decoder_ids is not None:
             if language is None and task is None and forced_decoder_ids[0][1] is None:
                 logger.warning_once(
                     "Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
```
```diff
@@ -1107,8 +1102,6 @@ class WhisperGenerationMixin:
                 )
         elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
             forced_decoder_ids = config.forced_decoder_ids
-        else:
-            forced_decoder_ids = None

         if forced_decoder_ids is not None and task is not None:
             logger.info(
```
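After the two hunks above, `forced_decoder_ids` comes from the generation config first and the model config second; the separate `kwargs["forced_decoder_ids"]` branch and the trailing `else: None` branch are gone, on the assumption that kwargs have already been merged into the generation config. A standalone sketch of the simplified lookup order; the function name is purely illustrative:

```python
def resolve_forced_decoder_ids(generation_config, config):
    """Illustrative helper mirroring the resolution order after this change."""
    forced_decoder_ids = generation_config.forced_decoder_ids
    if forced_decoder_ids is None and getattr(config, "forced_decoder_ids", None) is not None:
        forced_decoder_ids = config.forced_decoder_ids
    return forced_decoder_ids
```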
```diff
@@ -1288,21 +1281,6 @@ class WhisperGenerationMixin:
             "Passing `decoder_input_ids` is deprecated. Consider passing `prompt_ids` instead.",
         )

-    @staticmethod
-    def _set_token_ids(generation_config, config, kwargs):
-        eos_token_id = kwargs.pop("eos_token_id", None)
-        decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
-
-        eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id
-        decoder_start_token_id = (
-            decoder_start_token_id if decoder_start_token_id is not None else generation_config.decoder_start_token_id
-        )
-
-        generation_config.eos_token_id = eos_token_id if eos_token_id is not None else config.eos_token_id
-        generation_config.decoder_start_token_id = (
-            decoder_start_token_id if decoder_start_token_id is not None else config.decoder_start_token_id
-        )
-
     @staticmethod
     def _set_num_frames(return_token_timestamps, generation_config, kwargs):
         if return_token_timestamps:
```
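With the `_set_token_ids` helper deleted, `eos_token_id` and `decoder_start_token_id` overrides are expected to live on the generation config (set there directly, or merged in from `generate()` kwargs) rather than being popped out of `kwargs` by Whisper-specific code. A sketch of the direct route; the token ids shown are the standard multilingual Whisper values:

```python
from transformers import GenerationConfig

generation_config = GenerationConfig.from_pretrained("openai/whisper-tiny")
generation_config.decoder_start_token_id = 50258  # <|startoftranscript|>
generation_config.eos_token_id = 50257            # <|endoftext|>
# pred_ids = model.generate(input_features, generation_config=generation_config)
```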
```diff
@@ -1313,7 +1291,6 @@ class WhisperGenerationMixin:
                     "Model generation config has no `alignment_heads`, token-level timestamps not available. "
                     "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
                 )

             generation_config.num_frames = kwargs.pop("num_frames", None)

     @staticmethod
```
```diff
@@ -1517,47 +1494,21 @@ class WhisperGenerationMixin:
         return decoder_input_ids, kwargs

     @staticmethod
-    def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config, kwargs):
+    def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config):
         num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1)

-        passed_max_length = kwargs.pop("max_length", None)
-        passed_max_new_tokens = kwargs.pop("max_new_tokens", None)
-        max_length_config = getattr(generation_config, "max_length", None)
-        max_new_tokens_config = getattr(generation_config, "max_new_tokens", None)
-
-        max_new_tokens = None
-        max_length = None
-
         # Make sure we don't get larger than `max_length`
-        if passed_max_length is not None and passed_max_new_tokens is None:
-            max_length = min(passed_max_length + num_initial_tokens, config.max_target_positions)
-            logger.info(
-                f"Increase max_length from {passed_max_length} to {max_length} since input is conditioned on previous segment."
-            )
-        elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None:
+        if generation_config.max_length is not None and generation_config.max_new_tokens is None:
             max_length = min(generation_config.max_length + num_initial_tokens, config.max_target_positions)
             logger.info(
-                f"Increase max_length from {max_length_config} to {max_length} since input is conditioned on previous segment."
+                f"Increase max_length from {generation_config.max_length} to {max_length} since input is conditioned on previous segment."
             )
         elif (
-            passed_max_new_tokens is not None
-            and passed_max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
+            generation_config.max_new_tokens is not None
+            and generation_config.max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
         ):
             max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
-        elif (
-            passed_max_new_tokens is None
-            and max_new_tokens_config is not None
-            and max_new_tokens_config + decoder_input_ids.shape[-1] > config.max_target_positions
-        ):
-            max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
-
-        if max_new_tokens is not None:
-            kwargs["max_new_tokens"] = max_new_tokens
-
-        if max_length is not None:
-            kwargs["max_length"] = max_length
-
-        return kwargs
+            generation_config.max_new_tokens = max_new_tokens

     @staticmethod
     def _retrieve_compression_ratio(tokens, vocab_size):
```
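The rewritten helper no longer takes or returns `kwargs`: it reads `max_length` / `max_new_tokens` from the generation config and, when the requested budget would overflow the decoder, clamps `generation_config.max_new_tokens` in place, which is why the call site in the -666 hunk above drops the `kwargs = ...` assignment. A small sketch of the clamping branch with illustrative numbers:

```python
from transformers import GenerationConfig

max_target_positions = 448  # config.max_target_positions for Whisper
decoder_input_ids_len = 30  # length of the conditioning tokens (illustrative)

generation_config = GenerationConfig(max_new_tokens=440)
if (
    generation_config.max_new_tokens is not None
    and generation_config.max_new_tokens + decoder_input_ids_len > max_target_positions
):
    # clamp in place instead of stashing the value back into kwargs
    generation_config.max_new_tokens = max_target_positions - decoder_input_ids_len

print(generation_config.max_new_tokens)  # 418
```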