[whisper] fix multilingual fine-tuning (#30865)

* [whisper] fix multilingual fine-tuning

* config ids as well
This commit is contained in:
Sanchit Gandhi 2024-05-17 15:12:44 +01:00 committed by GitHub
parent 977ce58a78
commit 57edd84bdb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -425,12 +425,8 @@ def main():
if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
# We only need to set the language and task ids in a multilingual setting
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
model.generation_config.update(
**{
"language": data_args.language,
"task": data_args.task,
}
)
model.generation_config.language = data_args.language
model.generation_config.task = data_args.task
elif data_args.language is not None:
raise ValueError(
"Setting language token for an English-only checkpoint is not permitted. The language argument should "
@ -444,6 +440,9 @@ def main():
"Please use the `language` and `task` arguments instead"
)
model.generation_config.forced_decoder_ids = model_args.forced_decoder_ids
else:
model.generation_config.forced_decoder_ids = None
model.config.forced_decoder_ids = None
if model_args.suppress_tokens is not None:
logger.warning(