From dcf6df5b0d186aa510c35d2901de398628ec209e Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Tue, 22 Apr 2025 12:36:07 +0200
Subject: [PATCH] [qwen-omni] fix training (#37517)

* fix

* add text config

* fixup

* fix docs
---
 docs/source/en/quantization/torchao.md    |  3 +--
 .../configuration_qwen2_5_omni.py         | 15 ++++++++++++++
 .../qwen2_5_omni/modeling_qwen2_5_omni.py |  5 ++++-
 .../qwen2_5_omni/modular_qwen2_5_omni.py  | 20 ++++++++++++++++++-
 4 files changed, 39 insertions(+), 4 deletions(-)

diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md
index 62e3723403b..f3153d9f226 100644
--- a/docs/source/en/quantization/torchao.md
+++ b/docs/source/en/quantization/torchao.md
@@ -112,8 +112,6 @@ input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
 output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
 print(tokenizer.decode(output[0], skip_special_tokens=True))
 ```
-
-



@@ -332,6 +330,7 @@ quantized_model.push_to_hub(f"{USER_ID}/llama3-8b-int4wo-128", safe_serializatio
 tokenizer.push_to_hub(f"{USER_ID}/llama3-8b-int4wo-128")
 ```

+
 ## Loading quantized models


diff --git a/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
index 22873e16d07..f2e2333f675 100644
--- a/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
@@ -1045,5 +1045,20 @@ class Qwen2_5OmniConfig(PretrainedConfig):

         super().__init__(**kwargs)

+    @classmethod
+    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+        """
+        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
+        itself. On specific composite models, it is under a set of valid names.
+
+        Args:
+            decoder (`Optional[bool]`, *optional*, defaults to `False`):
+                If set to `True`, then only search for decoder config names.
+        """
+        # Overriden for deeply nested config like Qwen2-Omni. We don't have any omni model
+        # except for Qwen yet. This has to be generalized if more deeply nested configs are
+        # added. NOTE: currently method used only by vLLM
+        return self.thinker_config.get_text_config()
+

 __all__ = ["Qwen2_5OmniConfig", "Qwen2_5OmniThinkerConfig", "Qwen2_5OmniTalkerConfig", "Qwen2_5OmniToken2WavConfig"]
diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
index bda93a796e9..3a35b274904 100644
--- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -2503,7 +2503,9 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
+            loss = self.loss_function(
+                logits=logits, labels=labels, vocab_size=self.config.get_text_config().vocab_size
+            )

         if not return_dict:
             output = (logits,) + outputs
@@ -4384,6 +4386,7 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
         self.speaker_map = {}
         if config.enable_audio_output:
             self.enable_talker()
+        self.post_init()

     def enable_talker(self):
         self.talker = Qwen2_5OmniTalkerForConditionalGeneration(self.config.talker_config)
diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
index 61ef59d9a16..b6494f7aa1b 100644
--- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -1030,6 +1030,21 @@ class Qwen2_5OmniConfig(PretrainedConfig):

         super().__init__(**kwargs)

+    @classmethod
+    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+        """
+        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
+        itself. On specific composite models, it is under a set of valid names.
+
+        Args:
+            decoder (`Optional[bool]`, *optional*, defaults to `False`):
+                If set to `True`, then only search for decoder config names.
+        """
+        # Overriden for deeply nested config like Qwen2-Omni. We don't have any omni model
+        # except for Qwen yet. This has to be generalized if more deeply nested configs are
+        # added. NOTE: currently method used only by vLLM
+        return self.thinker_config.get_text_config()
+

 class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel):
     config_class = Qwen2_5OmniConfig
@@ -2463,7 +2478,9 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
+            loss = self.loss_function(
+                logits=logits, labels=labels, vocab_size=self.config.get_text_config().vocab_size
+            )

         if not return_dict:
             output = (logits,) + outputs
@@ -4053,6 +4070,7 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
         self.speaker_map = {}
         if config.enable_audio_output:
             self.enable_talker()
+        self.post_init()

     def enable_talker(self):
         self.talker = Qwen2_5OmniTalkerForConditionalGeneration(self.config.talker_config)
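
A minimal usage sketch of what the patch resolves to, assuming a transformers build that
includes this change (the checkpoint id below is only illustrative):

    from transformers import Qwen2_5OmniConfig

    # Illustrative checkpoint; any Qwen2.5-Omni checkpoint with a thinker_config should work.
    config = Qwen2_5OmniConfig.from_pretrained("Qwen/Qwen2.5-Omni-7B")

    # The thinker's text config carries the vocabulary size that the training loss in
    # Qwen2_5OmniThinkerForConditionalGeneration now reads via get_text_config(); the new
    # top-level override returns this same object.
    text_config = config.thinker_config.get_text_config()
    print(text_config.vocab_size)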