mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
parent
eed9ed6798
commit
9af1b6a80a
@ -1666,6 +1666,8 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
|
||||
inputs, generation_config.bos_token_id, model_kwargs
|
||||
)
|
||||
batch_size = input_ids.shape[0] // self.num_codebooks
|
||||
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
|
||||
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
|
||||
|
||||
# 4. Define other model kwargs
|
||||
model_kwargs["use_cache"] = generation_config.use_cache
|
||||
@ -2738,6 +2740,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
||||
inputs, generation_config.bos_token_id, model_kwargs
|
||||
)
|
||||
batch_size = inputs_tensor.shape[0]
|
||||
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
|
||||
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
|
||||
|
||||
# 4. Define other model kwargs
|
||||
model_kwargs["use_cache"] = generation_config.use_cache
|
||||
|
@ -1587,6 +1587,8 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
|
||||
inputs, generation_config.bos_token_id, model_kwargs
|
||||
)
|
||||
batch_size = input_ids.shape[0] // self.num_codebooks
|
||||
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
|
||||
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
|
||||
|
||||
# 4. Define other model kwargs
|
||||
model_kwargs["use_cache"] = generation_config.use_cache
|
||||
@ -2588,6 +2590,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
||||
inputs, generation_config.bos_token_id, model_kwargs
|
||||
)
|
||||
batch_size = inputs_tensor.shape[0]
|
||||
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
|
||||
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
|
||||
|
||||
# 4. Define other model kwargs
|
||||
model_kwargs["use_cache"] = generation_config.use_cache
|
||||
|
Loading…
Reference in New Issue
Block a user