Musicgen special tokens in tensors (#31420)

fix
This commit is contained in:
Raushan Turganbay 2024-06-17 10:09:27 +05:00 committed by GitHub
parent eed9ed6798
commit 9af1b6a80a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 0 deletions

View File

@ -1666,6 +1666,8 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = input_ids.shape[0] // self.num_codebooks
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
# 4. Define other model kwargs
model_kwargs["use_cache"] = generation_config.use_cache
@ -2738,6 +2740,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
# 4. Define other model kwargs
model_kwargs["use_cache"] = generation_config.use_cache

View File

@ -1587,6 +1587,8 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = input_ids.shape[0] // self.num_codebooks
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
# 4. Define other model kwargs
model_kwargs["use_cache"] = generation_config.use_cache
@ -2588,6 +2590,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
# 4. Define other model kwargs
model_kwargs["use_cache"] = generation_config.use_cache