[qwen-omni] fix training (#37517)

* fix

* add text config

* fixup

* fix docs
Authored by Raushan Turganbay on 2025-04-22 12:36:07 +02:00, committed by GitHub.
parent 9167fadab9
commit dcf6df5b0d
4 changed files with 39 additions and 4 deletions


@@ -112,8 +112,6 @@ input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
</hfoption>
-</hfoption>
<hfoption id="int4-weight-only">
@@ -332,6 +330,7 @@ quantized_model.push_to_hub(f"{USER_ID}/llama3-8b-int4wo-128", safe_serializatio
tokenizer.push_to_hub(f"{USER_ID}/llama3-8b-int4wo-128")
```
</hfoption>
+</hfoptions>
## Loading quantized models
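For reference, the tab group being closed in the doc hunks above wraps the torchao examples that the `llama3-8b-int4wo-128` snippets come from. A minimal sketch of the int4 weight-only flow those tabs describe might look like the following (the Llama 3 model id and bfloat16 dtype are assumptions here, and torchao must be installed):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

# int4 weight-only quantization with group size 128, matching the
# "llama3-8b-int4wo-128" hub name used in the snippet above.
quant_config = TorchAoConfig("int4_weight_only", group_size=128)
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quant_config,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

input_ids = tokenizer("What are we having for dinner?", return_tensors="pt").to("cuda")
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```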


@@ -1045,5 +1045,20 @@ class Qwen2_5OmniConfig(PretrainedConfig):
        super().__init__(**kwargs)
+
+    @classmethod
+    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+        """
+        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
+        itself. On specific composite models, it is under a set of valid names.
+
+        Args:
+            decoder (`Optional[bool]`, *optional*, defaults to `False`):
+                If set to `True`, then only search for decoder config names.
+        """
+        # Overriden for deeply nested config like Qwen2-Omni. We don't have any omni model
+        # except for Qwen yet. This has to be generalized if more deeply nested configs are
+        # added. NOTE: currently method used only by vLLM
+        return self.thinker_config.get_text_config()
__all__ = ["Qwen2_5OmniConfig", "Qwen2_5OmniThinkerConfig", "Qwen2_5OmniTalkerConfig", "Qwen2_5OmniToken2WavConfig"]
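To see what the new override resolves to, here is a small sketch assuming default config values (no checkpoint download needed): on this composite model the text settings live on the thinker sub-config rather than on the top-level `Qwen2_5OmniConfig`, and the override simply delegates down to it.

```python
from transformers import Qwen2_5OmniConfig

# Build a default composite config; the text settings are nested under the thinker.
config = Qwen2_5OmniConfig()

# This is the object the new get_text_config() override delegates to.
text_config = config.thinker_config.get_text_config()
print(type(text_config).__name__, text_config.vocab_size)
```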


@@ -2503,7 +2503,9 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
        loss = None
        if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
+            loss = self.loss_function(
+                logits=logits, labels=labels, vocab_size=self.config.get_text_config().vocab_size
+            )

        if not return_dict:
            output = (logits,) + outputs
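The hunk above is the training fix itself: the vocab size now comes from `config.get_text_config().vocab_size`, since on this composite model it lives on the nested text config rather than on the model-level config. A rough sketch of what that loss computation boils down to (shapes and vocab size are made up for illustration; transformers' causal-LM loss does the same shift-and-cross-entropy internally):

```python
import torch
import torch.nn.functional as F

# Stand-in for config.get_text_config().vocab_size; the real value comes from the text config.
vocab_size = 32
logits = torch.randn(2, 5, vocab_size)          # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 5))   # (batch, seq_len)

# Next-token shift + cross entropy, the core of the loss_function call above.
shift_logits = logits[:, :-1, :].reshape(-1, vocab_size)
shift_labels = labels[:, 1:].reshape(-1)
loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
print(loss.item())
```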
@@ -4384,6 +4386,7 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
        self.speaker_map = {}
        if config.enable_audio_output:
            self.enable_talker()
+        self.post_init()

    def enable_talker(self):
        self.talker = Qwen2_5OmniTalkerForConditionalGeneration(self.config.talker_config)
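The second hunk adds the `post_init()` call that concrete `PreTrainedModel` subclasses are expected to make at the end of `__init__`, so that weight initialization (and weight tying, where applicable) actually runs for the wrapper. A toy sketch of the pattern, with the class and attribute names invented purely for illustration:

```python
import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel


class ToyConfig(PretrainedConfig):
    model_type = "toy"

    def __init__(self, hidden_size=8, initializer_range=0.02, **kwargs):
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range
        super().__init__(**kwargs)


class ToyModel(PreTrainedModel):
    config_class = ToyConfig

    def __init__(self, config):
        super().__init__(config)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)
        # Without this call, _init_weights never runs for modules created here,
        # which is the kind of omission the hunk above fixes for the omni wrapper.
        self.post_init()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()


model = ToyModel(ToyConfig())
print(model.proj.weight.std())  # roughly initializer_range instead of PyTorch's default init
```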


@@ -1030,6 +1030,21 @@ class Qwen2_5OmniConfig(PretrainedConfig):
        super().__init__(**kwargs)
+
+    @classmethod
+    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+        """
+        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
+        itself. On specific composite models, it is under a set of valid names.
+
+        Args:
+            decoder (`Optional[bool]`, *optional*, defaults to `False`):
+                If set to `True`, then only search for decoder config names.
+        """
+        # Overriden for deeply nested config like Qwen2-Omni. We don't have any omni model
+        # except for Qwen yet. This has to be generalized if more deeply nested configs are
+        # added. NOTE: currently method used only by vLLM
+        return self.thinker_config.get_text_config()


class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel):
    config_class = Qwen2_5OmniConfig
@@ -2463,7 +2478,9 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
        loss = None
        if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
+            loss = self.loss_function(
+                logits=logits, labels=labels, vocab_size=self.config.get_text_config().vocab_size
+            )

        if not return_dict:
            output = (logits,) + outputs
@@ -4053,6 +4070,7 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
        self.speaker_map = {}
        if config.enable_audio_output:
            self.enable_talker()
+        self.post_init()

    def enable_talker(self):
        self.talker = Qwen2_5OmniTalkerForConditionalGeneration(self.config.talker_config)