Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)
[qwen-omni] fix training (#37517)

* fix
* add text config
* fixup
* fix docs

This commit is contained in:
parent 9167fadab9
commit dcf6df5b0d
````diff
@@ -112,8 +112,6 @@ input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
 output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
 print(tokenizer.decode(output[0], skip_special_tokens=True))
 ```
 </hfoption>
-
-</hfoption>
 <hfoption id="int4-weight-only">
 
@@ -332,6 +330,7 @@ quantized_model.push_to_hub(f"{USER_ID}/llama3-8b-int4wo-128", safe_serializatio
 tokenizer.push_to_hub(f"{USER_ID}/llama3-8b-int4wo-128")
 ```
 </hfoption>
+</hfoptions>
 
 
 ## Loading quantized models
````
```diff
@@ -1045,5 +1045,20 @@ class Qwen2_5OmniConfig(PretrainedConfig):
 
         super().__init__(**kwargs)
 
+    @classmethod
+    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+        """
+        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
+        itself. On specific composite models, it is under a set of valid names.
+
+        Args:
+            decoder (`Optional[bool]`, *optional*, defaults to `False`):
+                If set to `True`, then only search for decoder config names.
+        """
+        # Overridden for deeply nested configs like Qwen2-Omni. We don't have any omni model
+        # except for Qwen yet. This has to be generalized if more deeply nested configs are
+        # added. NOTE: currently this method is used only by vLLM.
+        return self.thinker_config.get_text_config()
+
 
 __all__ = ["Qwen2_5OmniConfig", "Qwen2_5OmniThinkerConfig", "Qwen2_5OmniTalkerConfig", "Qwen2_5OmniToken2WavConfig"]
```
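For context, the base `PretrainedConfig.get_text_config` only finds a text config nested one level down (e.g. under a `text_config` attribute); in Qwen2.5-Omni the text settings sit two levels deep, under `thinker_config.text_config`, which is why the override above delegates to the thinker. A minimal sketch of the resolved lookup, assuming the `Qwen/Qwen2.5-Omni-7B` checkpoint on the Hub and a transformers version that includes this change (not part of the commit itself):

```python
# Minimal sketch, not from the commit: resolve the nested text config.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Qwen/Qwen2.5-Omni-7B")

# The top-level config is a composite; text settings live two levels down.
text_config = config.get_text_config()  # delegates to thinker_config.get_text_config()
print(type(text_config).__name__)       # the thinker's text config class
print(text_config.vocab_size)           # vocab size used for the LM head and the loss
```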
```diff
@@ -2503,7 +2503,9 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
 
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
+            loss = self.loss_function(
+                logits=logits, labels=labels, vocab_size=self.config.get_text_config().vocab_size
+            )
 
         if not return_dict:
             output = (logits,) + outputs
```
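The hunk above is the actual training fix: `Qwen2_5OmniConfig` is a composite config with no flat `vocab_size` attribute of its own, so reading `self.config.vocab_size` failed as soon as `labels` were passed, i.e. only during training. A self-contained toy sketch of the failure mode and the fix, using hypothetical stand-in classes rather than the model's real ones:

```python
# Toy reproduction of the bug; TextConfig/ThinkerConfig/OmniConfig are hypothetical.
class TextConfig:
    vocab_size = 152064  # example value; the real size comes from the checkpoint

class ThinkerConfig:
    text_config = TextConfig()
    def get_text_config(self):
        return self.text_config

class OmniConfig:
    thinker_config = ThinkerConfig()
    def get_text_config(self):
        # mirrors the override added in this commit
        return self.thinker_config.get_text_config()

config = OmniConfig()
# config.vocab_size                          # AttributeError: no flat vocab_size on the composite config
print(config.get_text_config().vocab_size)   # 152064, resolved via the nested text config
```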
```diff
@@ -4384,6 +4386,7 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
         self.speaker_map = {}
         if config.enable_audio_output:
             self.enable_talker()
+        self.post_init()
 
     def enable_talker(self):
         self.talker = Qwen2_5OmniTalkerForConditionalGeneration(self.config.talker_config)
```
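The one-line addition above also matters for training: `post_init()` is the standard end-of-`__init__` hook on `PreTrainedModel` that applies weight initialization and related setup, and it was missing from this `__init__`. A schematic sketch of the expected pattern, assuming only the public `transformers` API (the toy classes are illustrative, not from the commit):

```python
# Schematic sketch of the canonical pattern; ToyConfig/ToyModel are hypothetical.
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel

class ToyConfig(PretrainedConfig):
    model_type = "toy"

class ToyModel(PreTrainedModel):
    config_class = ToyConfig

    def __init__(self, config):
        super().__init__(config)
        self.proj = nn.Linear(8, 8)
        # Every PreTrainedModel subclass is expected to end __init__ with this:
        # it applies _init_weights and finishes backward-compatibility setup.
        self.post_init()

model = ToyModel(ToyConfig())  # weights are now initialized via the standard path
```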
The same changes are mirrored in the model's second copy of the config and modeling code:

```diff
@@ -1030,6 +1030,21 @@ class Qwen2_5OmniConfig(PretrainedConfig):
 
         super().__init__(**kwargs)
 
+    @classmethod
+    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+        """
+        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
+        itself. On specific composite models, it is under a set of valid names.
+
+        Args:
+            decoder (`Optional[bool]`, *optional*, defaults to `False`):
+                If set to `True`, then only search for decoder config names.
+        """
+        # Overridden for deeply nested configs like Qwen2-Omni. We don't have any omni model
+        # except for Qwen yet. This has to be generalized if more deeply nested configs are
+        # added. NOTE: currently this method is used only by vLLM.
+        return self.thinker_config.get_text_config()
+
 
 class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel):
     config_class = Qwen2_5OmniConfig
```
```diff
@@ -2463,7 +2478,9 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
 
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
+            loss = self.loss_function(
+                logits=logits, labels=labels, vocab_size=self.config.get_text_config().vocab_size
+            )
 
         if not return_dict:
             output = (logits,) + outputs
```
```diff
@@ -4053,6 +4070,7 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
         self.speaker_map = {}
         if config.enable_audio_output:
             self.enable_talker()
+        self.post_init()
 
     def enable_talker(self):
         self.talker = Qwen2_5OmniTalkerForConditionalGeneration(self.config.talker_config)
```