From 9af1b6a80adbac906ba770d23ddf95a147f2f0a0 Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Mon, 17 Jun 2024 10:09:27 +0500
Subject: [PATCH] Musicgen special tokens in tensors (#31420)

fix
---
 src/transformers/models/musicgen/modeling_musicgen.py          | 4 ++++
 .../models/musicgen_melody/modeling_musicgen_melody.py         | 4 ++++
 2 files changed, 8 insertions(+)

diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index 8c126f5d809..15d97d61e0f 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -1666,6 +1666,8 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             inputs, generation_config.bos_token_id, model_kwargs
         )
         batch_size = input_ids.shape[0] // self.num_codebooks
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
 
         # 4. Define other model kwargs
         model_kwargs["use_cache"] = generation_config.use_cache
@@ -2738,6 +2740,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
             inputs, generation_config.bos_token_id, model_kwargs
         )
         batch_size = inputs_tensor.shape[0]
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
 
         # 4. Define other model kwargs
         model_kwargs["use_cache"] = generation_config.use_cache
diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
index 2b6ff1b6be7..8bf622af8b7 100644
--- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
+++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
@@ -1587,6 +1587,8 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
             inputs, generation_config.bos_token_id, model_kwargs
         )
         batch_size = input_ids.shape[0] // self.num_codebooks
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
 
         # 4. Define other model kwargs
         model_kwargs["use_cache"] = generation_config.use_cache
@@ -2588,6 +2590,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
             inputs, generation_config.bos_token_id, model_kwargs
         )
         batch_size = inputs_tensor.shape[0]
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
 
         # 4. Define other model kwargs
         model_kwargs["use_cache"] = generation_config.use_cache