mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
[MusicGen] Add streamer to generate (#25320)
* [MusicGen] Add streamer to generate * add to for cond generation * add test * finish * torch only * fix type hint * yield audio chunks * fix typehint * remove test
This commit is contained in:
parent
866df66fe4
commit
0dd06c3f78
@ -18,7 +18,7 @@ import inspect
|
||||
import math
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -48,6 +48,9 @@ from ..auto.modeling_auto import AutoModel
|
||||
from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...generation.streamers import BaseStreamer
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "MusicgenConfig"
|
||||
@ -1185,6 +1188,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
synced_gpus: Optional[bool] = None,
|
||||
streamer: Optional["BaseStreamer"] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -1225,6 +1229,9 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
|
||||
generation config an error is thrown. This feature is intended for advanced users.
|
||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
streamer (`BaseStreamer`, *optional*):
|
||||
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
||||
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
||||
@ -1336,6 +1343,9 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
|
||||
max_length=generation_config.max_length,
|
||||
)
|
||||
|
||||
if streamer is not None:
|
||||
streamer.put(input_ids.cpu())
|
||||
|
||||
# stash the delay mask so that we don't have to recompute it in each forward pass
|
||||
model_kwargs["delay_pattern_mask"] = delay_pattern_mask
|
||||
|
||||
@ -1387,6 +1397,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
|
||||
output_scores=generation_config.output_scores,
|
||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||
synced_gpus=synced_gpus,
|
||||
streamer=streamer,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@ -1412,6 +1423,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
|
||||
output_scores=generation_config.output_scores,
|
||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||
synced_gpus=synced_gpus,
|
||||
streamer=streamer,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@ -2186,6 +2198,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
synced_gpus: Optional[bool] = None,
|
||||
streamer: Optional["BaseStreamer"] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -2226,6 +2239,9 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
||||
generation config an error is thrown. This feature is intended for advanced users.
|
||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
streamer (`BaseStreamer`, *optional*):
|
||||
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
||||
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
||||
@ -2368,6 +2384,10 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
||||
# stash the delay mask so that we don't have to recompute in each forward pass
|
||||
model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask
|
||||
|
||||
# input_ids are ready to be placed on the streamer (if used)
|
||||
if streamer is not None:
|
||||
streamer.put(input_ids.cpu())
|
||||
|
||||
# 7. determine generation mode
|
||||
is_greedy_gen_mode = (
|
||||
(generation_config.num_beams == 1)
|
||||
@ -2416,6 +2436,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
||||
output_scores=generation_config.output_scores,
|
||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||
synced_gpus=synced_gpus,
|
||||
streamer=streamer,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@ -2442,6 +2463,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
||||
output_scores=generation_config.output_scores,
|
||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||
synced_gpus=synced_gpus,
|
||||
streamer=streamer,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user