mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
parent
eed9ed6798
commit
9af1b6a80a
@ -1666,6 +1666,8 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
|
|||||||
inputs, generation_config.bos_token_id, model_kwargs
|
inputs, generation_config.bos_token_id, model_kwargs
|
||||||
)
|
)
|
||||||
batch_size = input_ids.shape[0] // self.num_codebooks
|
batch_size = input_ids.shape[0] // self.num_codebooks
|
||||||
|
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
|
||||||
|
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
|
||||||
|
|
||||||
# 4. Define other model kwargs
|
# 4. Define other model kwargs
|
||||||
model_kwargs["use_cache"] = generation_config.use_cache
|
model_kwargs["use_cache"] = generation_config.use_cache
|
||||||
@ -2738,6 +2740,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
|||||||
inputs, generation_config.bos_token_id, model_kwargs
|
inputs, generation_config.bos_token_id, model_kwargs
|
||||||
)
|
)
|
||||||
batch_size = inputs_tensor.shape[0]
|
batch_size = inputs_tensor.shape[0]
|
||||||
|
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
|
||||||
|
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
|
||||||
|
|
||||||
# 4. Define other model kwargs
|
# 4. Define other model kwargs
|
||||||
model_kwargs["use_cache"] = generation_config.use_cache
|
model_kwargs["use_cache"] = generation_config.use_cache
|
||||||
|
@ -1587,6 +1587,8 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
|
|||||||
inputs, generation_config.bos_token_id, model_kwargs
|
inputs, generation_config.bos_token_id, model_kwargs
|
||||||
)
|
)
|
||||||
batch_size = input_ids.shape[0] // self.num_codebooks
|
batch_size = input_ids.shape[0] // self.num_codebooks
|
||||||
|
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
|
||||||
|
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
|
||||||
|
|
||||||
# 4. Define other model kwargs
|
# 4. Define other model kwargs
|
||||||
model_kwargs["use_cache"] = generation_config.use_cache
|
model_kwargs["use_cache"] = generation_config.use_cache
|
||||||
@ -2588,6 +2590,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
|||||||
inputs, generation_config.bos_token_id, model_kwargs
|
inputs, generation_config.bos_token_id, model_kwargs
|
||||||
)
|
)
|
||||||
batch_size = inputs_tensor.shape[0]
|
batch_size = inputs_tensor.shape[0]
|
||||||
|
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
|
||||||
|
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
|
||||||
|
|
||||||
# 4. Define other model kwargs
|
# 4. Define other model kwargs
|
||||||
model_kwargs["use_cache"] = generation_config.use_cache
|
model_kwargs["use_cache"] = generation_config.use_cache
|
||||||
|
Loading…
Reference in New Issue
Block a user