Fix whisper kwargs and generation config (#30018)
* clean-up whisper kwargs
* failing test
commit 76fa17c166
parent 9b5a6450d4
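For orientation, here is a minimal usage sketch of the code path this commit touches: options such as num_beams, temperature and max_new_tokens reach WhisperGenerationMixin.generate as keyword arguments and, after this change, are read back from the GenerationConfig instead of the raw kwargs dict. The checkpoint name and the silent dummy clip are illustrative assumptions, not part of the diff.

    import numpy as np
    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

    speech = np.zeros(16000, dtype=np.float32)  # one second of silence as dummy input
    inputs = processor(speech, sampling_rate=16000, return_tensors="pt")
    # num_beams / max_new_tokens end up on the generation config used inside generate()
    pred_ids = model.generate(inputs.input_features, num_beams=2, max_new_tokens=128)
    print(processor.batch_decode(pred_ids, skip_special_tokens=True)[0])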
@@ -511,7 +511,6 @@ class WhisperGenerationMixin:
         self._set_language_and_task(
             language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
         )
-        self._set_token_ids(generation_config=generation_config, config=self.config, kwargs=kwargs)
         self._set_num_frames(
             return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
         )
@@ -546,13 +545,13 @@ class WhisperGenerationMixin:
             logits_processor=logits_processor,
             begin_index=begin_index,  # begin index is index of first generated decoder token
             is_shortform=is_shortform,
-            num_beams=kwargs.get("num_beams", 1),
+            num_beams=generation_config.num_beams,
         )
 
         # 5. If we're in shortform mode, simple generate the whole input at once and return the output
         if is_shortform:
             if temperature is not None:
-                kwargs["temperature"] = temperature
+                generation_config.temperature = temperature
 
             decoder_input_ids = kwargs.pop("decoder_input_ids", None)
             if decoder_input_ids is None:
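The hunk above swaps two direct kwargs lookups for reads from the generation config. A minimal sketch of the pattern, using hypothetical helper names that are not part of the diff:

    # Hypothetical helpers contrasting the old and new lookup style.
    def pick_num_beams_old(kwargs: dict) -> int:
        return kwargs.get("num_beams", 1)  # re-reads raw kwargs, silently defaults to 1

    def pick_num_beams_new(generation_config) -> int:
        return generation_config.num_beams  # single source of truth after the kwargs merge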
@@ -564,8 +563,8 @@ class WhisperGenerationMixin:
                     [prompt_ids[None].repeat(decoder_input_ids.shape[0], 1), decoder_input_ids], dim=-1
                 )
 
-            if kwargs.get("max_new_tokens", 0) + decoder_input_ids.shape[-1] > self.config.max_target_positions:
-                max_new_tokens = kwargs.get("max_new_tokens", 0)
+            max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0
+            if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions:
                 raise ValueError(
                     f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
                     f"is {max_new_tokens}. Thus, the combined length of "
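The shortform guard above now takes max_new_tokens from the generation config before comparing against config.max_target_positions (448 for Whisper's decoder). A small worked example of the bound, with illustrative numbers:

    max_target_positions = 448   # Whisper's decoder length limit
    decoder_input_len = 32       # prompt_ids plus special start tokens
    max_new_tokens = 400         # requested by the caller
    assert max_new_tokens + decoder_input_len <= max_target_positions
    # 400 + 32 = 432 <= 448, so generation proceeds; asking for 448 new tokens
    # here would trigger the ValueError raised in the diff above.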
@@ -666,11 +665,10 @@ class WhisperGenerationMixin:
             )
 
             # 6.6 set max new tokens or max length
-            kwargs = self._set_max_new_tokens_and_length(
+            self._set_max_new_tokens_and_length(
                 config=self.config,
                 decoder_input_ids=decoder_input_ids,
                 generation_config=generation_config,
-                kwargs=kwargs,
             )
 
             # 6.7 Set current `begin_index` for all logit processors
@@ -770,9 +768,9 @@ class WhisperGenerationMixin:
 
         for fallback_idx, temperature in enumerate(temperatures):
             generation_config.do_sample = temperature is not None and temperature > 0.0
-
             generation_config.temperature = temperature if generation_config.do_sample else 1.0
-            generation_config.num_beams = kwargs.get("num_beams", 1) if not generation_config.do_sample else 1
+            if generation_config.do_sample:
+                generation_config.num_beams = 1
 
             generate_kwargs = copy.copy(kwargs)
             for key in ["do_sample", "temperature", "num_beams"]:
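For context, the loop above implements temperature fallback: each retry uses the next temperature in the tuple, sampling only when the temperature is positive, and beam search is disabled while sampling. A simplified, self-contained sketch of that control flow (the beam count of 5 is illustrative, not from the diff):

    temperatures = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
    for temperature in temperatures:
        do_sample = temperature is not None and temperature > 0.0
        decode_options = {
            "do_sample": do_sample,
            "temperature": temperature if do_sample else 1.0,
            "num_beams": 1 if do_sample else 5,  # illustrative beam count
        }
        # ... run one decoding pass with decode_options and break once the
        # compression-ratio / log-probability checks pass ...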
@@ -1095,11 +1093,8 @@ class WhisperGenerationMixin:
         task = getattr(generation_config, "task", None)
         language = getattr(generation_config, "language", None)
 
-        if kwargs.get("forced_decoder_ids", None) is not None:
-            forced_decoder_ids = kwargs["forced_decoder_ids"]
-        elif hasattr(generation_config, "forced_decoder_ids") and generation_config.forced_decoder_ids is not None:
-            forced_decoder_ids = generation_config.forced_decoder_ids
-
+        forced_decoder_ids = generation_config.forced_decoder_ids
+        if forced_decoder_ids is not None:
             if language is None and task is None and forced_decoder_ids[0][1] is None:
                 logger.warning_once(
                     "Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
@@ -1107,8 +1102,6 @@ class WhisperGenerationMixin:
                 )
         elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
             forced_decoder_ids = config.forced_decoder_ids
-        else:
-            forced_decoder_ids = None
 
         if forced_decoder_ids is not None and task is not None:
             logger.info(
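Taken together, the two hunks above simplify how forced_decoder_ids is resolved: the generation config is read first, the model config is the fallback, and otherwise the value stays None (the explicit else branch becomes redundant). A hedged sketch of that precedence, not the exact method body:

    from types import SimpleNamespace

    def resolve_forced_decoder_ids(generation_config, config):
        # generation config wins; model config is the fallback; otherwise stays None
        forced_decoder_ids = generation_config.forced_decoder_ids
        if forced_decoder_ids is None and getattr(config, "forced_decoder_ids", None) is not None:
            forced_decoder_ids = config.forced_decoder_ids
        return forced_decoder_ids

    gen_cfg = SimpleNamespace(forced_decoder_ids=None)
    model_cfg = SimpleNamespace(forced_decoder_ids=[(1, 50259)])  # illustrative ids
    assert resolve_forced_decoder_ids(gen_cfg, model_cfg) == [(1, 50259)]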
@@ -1288,21 +1281,6 @@ class WhisperGenerationMixin:
                 "Passing `decoder_input_ids` is deprecated. Consider passing `prompt_ids` instead.",
             )
 
-    @staticmethod
-    def _set_token_ids(generation_config, config, kwargs):
-        eos_token_id = kwargs.pop("eos_token_id", None)
-        decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
-
-        eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id
-        decoder_start_token_id = (
-            decoder_start_token_id if decoder_start_token_id is not None else generation_config.decoder_start_token_id
-        )
-
-        generation_config.eos_token_id = eos_token_id if eos_token_id is not None else config.eos_token_id
-        generation_config.decoder_start_token_id = (
-            decoder_start_token_id if decoder_start_token_id is not None else config.decoder_start_token_id
-        )
-
     @staticmethod
     def _set_num_frames(return_token_timestamps, generation_config, kwargs):
         if return_token_timestamps:
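With _set_token_ids deleted above, eos_token_id and decoder_start_token_id are no longer reconciled from raw generate() kwargs inside this mixin. A hedged sketch of setting them explicitly on a GenerationConfig instead (checkpoint name is illustrative; the defaults are simply copied from the model config here):

    from transformers import GenerationConfig, WhisperForConditionalGeneration

    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    gen_config = GenerationConfig.from_model_config(model.config)
    # Overrides live on the config object rather than being popped out of kwargs.
    gen_config.eos_token_id = model.config.eos_token_id
    gen_config.decoder_start_token_id = model.config.decoder_start_token_id
    # model.generate(input_features, generation_config=gen_config, ...)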
@@ -1313,7 +1291,6 @@ class WhisperGenerationMixin:
                     "Model generation config has no `alignment_heads`, token-level timestamps not available. "
                     "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
                 )
-
             generation_config.num_frames = kwargs.pop("num_frames", None)
 
     @staticmethod
@@ -1517,47 +1494,21 @@ class WhisperGenerationMixin:
         return decoder_input_ids, kwargs
 
     @staticmethod
-    def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config, kwargs):
+    def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config):
         num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1)
 
-        passed_max_length = kwargs.pop("max_length", None)
-        passed_max_new_tokens = kwargs.pop("max_new_tokens", None)
-        max_length_config = getattr(generation_config, "max_length", None)
-        max_new_tokens_config = getattr(generation_config, "max_new_tokens", None)
-
-        max_new_tokens = None
-        max_length = None
-
         # Make sure we don't get larger than `max_length`
-        if passed_max_length is not None and passed_max_new_tokens is None:
-            max_length = min(passed_max_length + num_initial_tokens, config.max_target_positions)
-            logger.info(
-                f"Increase max_length from {passed_max_length} to {max_length} since input is conditioned on previous segment."
-            )
-        elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None:
+        if generation_config.max_length is not None and generation_config.max_new_tokens is None:
             max_length = min(generation_config.max_length + num_initial_tokens, config.max_target_positions)
             logger.info(
-                f"Increase max_length from {max_length_config} to {max_length} since input is conditioned on previous segment."
+                f"Increase max_length from {generation_config.max_length} to {max_length} since input is conditioned on previous segment."
            )
         elif (
-            passed_max_new_tokens is not None
-            and passed_max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
+            generation_config.max_new_tokens is not None
+            and generation_config.max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
         ):
             max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
-        elif (
-            passed_max_new_tokens is None
-            and max_new_tokens_config is not None
-            and max_new_tokens_config + decoder_input_ids.shape[-1] > config.max_target_positions
-        ):
-            max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
-
-        if max_new_tokens is not None:
-            kwargs["max_new_tokens"] = max_new_tokens
-
-        if max_length is not None:
-            kwargs["max_length"] = max_length
-
-        return kwargs
+            generation_config.max_new_tokens = max_new_tokens
 
     @staticmethod
     def _retrieve_compression_ratio(tokens, vocab_size):
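The rewritten _set_max_new_tokens_and_length above no longer pops and returns kwargs; it reads max_length / max_new_tokens from the generation config and writes the clamped value back onto it. A worked example of the clamp with illustrative numbers:

    max_target_positions = 448   # Whisper's decoder length limit
    decoder_input_len = 64       # tokens the decoder is already conditioned on
    max_new_tokens = 448         # generation_config.max_new_tokens as requested
    if max_new_tokens + decoder_input_len > max_target_positions:
        max_new_tokens = max_target_positions - decoder_input_len
    assert max_new_tokens == 384  # value written back to generation_config.max_new_tokens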