[MusicGen] Add streamer to generate (#25320)

* [MusicGen] Add streamer to generate

* add to for cond generation

* add test

* finish

* torch only

* fix type hint

* yield audio chunks

* fix typehint

* remove test
This commit is contained in:
Sanchit Gandhi 2023-09-14 15:59:09 +01:00 committed by GitHub
parent 866df66fe4
commit 0dd06c3f78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -18,7 +18,7 @@ import inspect
import math
import random
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -48,6 +48,9 @@ from ..auto.modeling_auto import AutoModel
from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig
if TYPE_CHECKING:
from ...generation.streamers import BaseStreamer
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MusicgenConfig"
@ -1185,6 +1188,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
synced_gpus: Optional[bool] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
):
"""
@ -1225,6 +1229,9 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
generation config an error is thrown. This feature is intended for advanced users.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@ -1336,6 +1343,9 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
max_length=generation_config.max_length,
)
if streamer is not None:
streamer.put(input_ids.cpu())
# stash the delay mask so that we don't have to recompute it in each forward pass
model_kwargs["delay_pattern_mask"] = delay_pattern_mask
@ -1387,6 +1397,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)
@ -1412,6 +1423,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)
@ -2186,6 +2198,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
synced_gpus: Optional[bool] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
):
"""
@ -2226,6 +2239,9 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
generation config an error is thrown. This feature is intended for advanced users.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@ -2368,6 +2384,10 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
# stash the delay mask so that we don't have to recompute in each forward pass
model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask
# input_ids are ready to be placed on the streamer (if used)
if streamer is not None:
streamer.put(input_ids.cpu())
# 7. determine generation mode
is_greedy_gen_mode = (
(generation_config.num_beams == 1)
@ -2416,6 +2436,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)
@ -2442,6 +2463,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)