MusicGen Update (#27084)

* [MusicGen] Add stereo model

* safe serialization

* Update src/transformers/models/musicgen/modeling_musicgen.py

* split over 2 lines

* fix slow tests on cuda
This commit is contained in:
Sanchit Gandhi 2023-11-08 13:26:02 +00:00 committed by GitHub
parent 5ef650b0ae
commit f16ff0f07e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 244 additions and 28 deletions

View File

@ -57,6 +57,11 @@ Generation is limited by the sinusoidal positional embeddings to 30 second input
than 30 seconds of audio (1503 tokens), and input audio passed by Audio-Prompted Generation contributes to this limit so,
given an input of 20 seconds of audio, MusicGen cannot generate more than 10 seconds of additional audio.
Transformers supports both mono (1-channel) and stereo (2-channel) variants of MusicGen. The mono channel versions
generate a single set of codebooks. The stereo versions generate 2 sets of codebooks, 1 for each channel (left/right),
and each set of codebooks is decoded independently through the audio compression model. The audio streams for each
channel are combined to give the final stereo output.
### Unconditional Generation
The inputs for unconditional (or 'null') generation can be obtained through the method

View File

@ -75,6 +75,9 @@ class MusicgenDecoderConfig(PretrainedConfig):
The number of parallel codebooks forwarded to the model.
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
Whether input and output word embeddings should be tied.
audio_channels (`int`, *optional*, defaults to 1
Number of channels in the audio data. Either 1 for mono or 2 for stereo. Stereo models generate a separate
audio stream for the left/right output channels. Mono models generate a single audio stream output.
"""
model_type = "musicgen_decoder"
keys_to_ignore_at_inference = ["past_key_values"]
@ -96,6 +99,7 @@ class MusicgenDecoderConfig(PretrainedConfig):
initializer_factor=0.02,
scale_embedding=False,
num_codebooks=4,
audio_channels=1,
pad_token_id=2048,
bos_token_id=2048,
eos_token_id=None,
@ -117,6 +121,11 @@ class MusicgenDecoderConfig(PretrainedConfig):
self.use_cache = use_cache
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
self.num_codebooks = num_codebooks
if audio_channels not in [1, 2]:
raise ValueError(f"Expected 1 (mono) or 2 (stereo) audio channels, got {audio_channels} channels.")
self.audio_channels = audio_channels
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,

View File

@ -88,32 +88,48 @@ def rename_state_dict(state_dict: OrderedDict, hidden_size: int) -> Tuple[Dict,
def decoder_config_from_checkpoint(checkpoint: str) -> MusicgenDecoderConfig:
if checkpoint == "small":
if checkpoint == "small" or checkpoint == "facebook/musicgen-stereo-small":
# default config values
hidden_size = 1024
num_hidden_layers = 24
num_attention_heads = 16
elif checkpoint == "medium":
elif checkpoint == "medium" or checkpoint == "facebook/musicgen-stereo-medium":
hidden_size = 1536
num_hidden_layers = 48
num_attention_heads = 24
elif checkpoint == "large":
elif checkpoint == "large" or checkpoint == "facebook/musicgen-stereo-large":
hidden_size = 2048
num_hidden_layers = 48
num_attention_heads = 32
else:
raise ValueError(f"Checkpoint should be one of `['small', 'medium', 'large']`, got {checkpoint}.")
raise ValueError(
"Checkpoint should be one of `['small', 'medium', 'large']` for the mono checkpoints, "
"or `['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` "
f"for the stereo checkpoints, got {checkpoint}."
)
if "stereo" in checkpoint:
audio_channels = 2
num_codebooks = 8
else:
audio_channels = 1
num_codebooks = 4
config = MusicgenDecoderConfig(
hidden_size=hidden_size,
ffn_dim=hidden_size * 4,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_codebooks=num_codebooks,
audio_channels=audio_channels,
)
return config
@torch.no_grad()
def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu"):
def convert_musicgen_checkpoint(
checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu", safe_serialization=False
):
fairseq_model = MusicGen.get_pretrained(checkpoint, device=device)
decoder_config = decoder_config_from_checkpoint(checkpoint)
@ -146,18 +162,20 @@ def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=No
model.enc_to_dec_proj.load_state_dict(enc_dec_proj_state_dict)
# check we can do a forward pass
input_ids = torch.arange(0, 8, dtype=torch.long).reshape(2, -1)
decoder_input_ids = input_ids.reshape(2 * 4, -1)
input_ids = torch.arange(0, 2 * decoder_config.num_codebooks, dtype=torch.long).reshape(2, -1)
decoder_input_ids = input_ids.reshape(2 * decoder_config.num_codebooks, -1)
with torch.no_grad():
logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
if logits.shape != (8, 1, 2048):
if logits.shape != (2 * decoder_config.num_codebooks, 1, 2048):
raise ValueError("Incorrect shape for logits")
# now construct the processor
tokenizer = AutoTokenizer.from_pretrained("t5-base")
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/encodec_32khz", padding_side="left")
feature_extractor = AutoFeatureExtractor.from_pretrained(
"facebook/encodec_32khz", padding_side="left", feature_size=decoder_config.audio_channels
)
processor = MusicgenProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
@ -173,12 +191,12 @@ def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=No
if pytorch_dump_folder is not None:
Path(pytorch_dump_folder).mkdir(exist_ok=True)
logger.info(f"Saving model {checkpoint} to {pytorch_dump_folder}")
model.save_pretrained(pytorch_dump_folder)
model.save_pretrained(pytorch_dump_folder, safe_serialization=safe_serialization)
processor.save_pretrained(pytorch_dump_folder)
if repo_id:
logger.info(f"Pushing model {checkpoint} to {repo_id}")
model.push_to_hub(repo_id)
model.push_to_hub(repo_id, safe_serialization=safe_serialization)
processor.push_to_hub(repo_id)
@ -189,7 +207,10 @@ if __name__ == "__main__":
"--checkpoint",
default="small",
type=str,
help="Checkpoint size of the MusicGen model you'd like to convert. Can be one of: `['small', 'medium', 'large']`.",
help="Checkpoint size of the MusicGen model you'd like to convert. Can be one of: "
"`['small', 'medium', 'large']` for the mono checkpoints, or "
"`['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` "
"for the stereo checkpoints.",
)
parser.add_argument(
"--pytorch_dump_folder",
@ -204,6 +225,11 @@ if __name__ == "__main__":
parser.add_argument(
"--device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
)
parser.add_argument(
"--safe_serialization",
action="store_true",
help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).",
)
args = parser.parse_args()
convert_musicgen_checkpoint(args.checkpoint, args.pytorch_dump_folder, args.push_to_hub)

View File

@ -1077,21 +1077,33 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
)
channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks
# we only apply the mask if we have a large enough seq len - otherwise we return as is
if max_length < 2 * num_codebooks - 1:
if max_length < 2 * channel_codebooks - 1:
return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
# fill the shifted ids with the prompt entries, offset by the codebook idx
for codebook in range(num_codebooks):
input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
for codebook in range(channel_codebooks):
if self.config.audio_channels == 1:
# mono channel - loop over the codebooks one-by-one
input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
else:
# left/right channels are interleaved in the generated codebooks, so handle one then the other
input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook]
input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1]
# construct a pattern mask that indicates the positions of padding tokens for each codebook
# first fill the upper triangular part (the EOS padding)
delay_pattern = torch.triu(
torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
torch.ones((channel_codebooks, max_length), dtype=torch.bool), diagonal=max_length - channel_codebooks + 1
)
# then fill the lower triangular part (the BOS padding)
delay_pattern = delay_pattern + torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.bool))
if self.config.audio_channels == 2:
# for left/right channel we need to duplicate every row of the pattern mask in an interleaved fashion
delay_pattern = delay_pattern.repeat_interleave(2, dim=0)
mask = ~delay_pattern.to(input_ids.device)
input_ids = mask * input_ids_shifted + ~mask * pad_token_id
@ -1856,6 +1868,11 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
"disabled by setting `chunk_length=None` in the audio encoder."
)
if self.config.audio_channels == 2 and audio_codes.shape[2] == self.decoder.num_codebooks // 2:
# mono input through encodec that we convert to stereo
audio_codes = audio_codes.repeat_interleave(2, dim=2)
decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)
# Decode
@ -2074,12 +2091,42 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
# 3. make sure that encoder returns `ModelOutput`
model_input_name = model_input_name if model_input_name is not None else self.audio_encoder.main_input_name
encoder_kwargs["return_dict"] = True
encoder_kwargs[model_input_name] = input_values
audio_encoder_outputs = encoder.encode(**encoder_kwargs)
if self.decoder.config.audio_channels == 1:
encoder_kwargs[model_input_name] = input_values
audio_encoder_outputs = encoder.encode(**encoder_kwargs)
audio_codes = audio_encoder_outputs.audio_codes
audio_scales = audio_encoder_outputs.audio_scales
audio_codes = audio_encoder_outputs.audio_codes
frames, bsz, codebooks, seq_len = audio_codes.shape
frames, bsz, codebooks, seq_len = audio_codes.shape
else:
if input_values.shape[1] != 2:
raise ValueError(
f"Expected stereo audio (2-channels) but example has {input_values.shape[1]} channel."
)
encoder_kwargs[model_input_name] = input_values[:, :1, :]
audio_encoder_outputs_left = encoder.encode(**encoder_kwargs)
audio_codes_left = audio_encoder_outputs_left.audio_codes
audio_scales_left = audio_encoder_outputs_left.audio_scales
encoder_kwargs[model_input_name] = input_values[:, 1:, :]
audio_encoder_outputs_right = encoder.encode(**encoder_kwargs)
audio_codes_right = audio_encoder_outputs_right.audio_codes
audio_scales_right = audio_encoder_outputs_right.audio_scales
frames, bsz, codebooks, seq_len = audio_codes_left.shape
# copy alternating left/right channel codes into stereo codebook
audio_codes = audio_codes_left.new_ones((frames, bsz, 2 * codebooks, seq_len))
audio_codes[:, :, ::2, :] = audio_codes_left
audio_codes[:, :, 1::2, :] = audio_codes_right
if audio_scales_left != [None] or audio_scales_right != [None]:
audio_scales = torch.stack([audio_scales_left, audio_scales_right], dim=1)
else:
audio_scales = [None] * bsz
if frames != 1:
raise ValueError(
@ -2090,7 +2137,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)
model_kwargs["decoder_input_ids"] = decoder_input_ids
model_kwargs["audio_scales"] = audio_encoder_outputs.audio_scales
model_kwargs["audio_scales"] = audio_scales
return model_kwargs
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
@ -2433,16 +2480,25 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
if audio_scales is None:
audio_scales = [None] * batch_size
output_values = self.audio_encoder.decode(
output_ids,
audio_scales=audio_scales,
)
if self.decoder.config.audio_channels == 1:
output_values = self.audio_encoder.decode(
output_ids,
audio_scales=audio_scales,
).audio_values
else:
codec_outputs_left = self.audio_encoder.decode(output_ids[:, :, ::2, :], audio_scales=audio_scales)
output_values_left = codec_outputs_left.audio_values
codec_outputs_right = self.audio_encoder.decode(output_ids[:, :, 1::2, :], audio_scales=audio_scales)
output_values_right = codec_outputs_right.audio_values
output_values = torch.cat([output_values_left, output_values_right], dim=1)
if generation_config.return_dict_in_generate:
outputs.sequences = output_values.audio_values
outputs.sequences = output_values
return outputs
else:
return output_values.audio_values
return output_values
def get_unconditional_inputs(self, num_samples=1):
"""

View File

@ -379,6 +379,27 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config.audio_channels = 2
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
max_length=max_length,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
self.assertNotIn(config.pad_token_id, output_generate)
def prepare_musicgen_inputs_dict(
config,
@ -1102,6 +1123,29 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
input_dict["input_ids"], attention_mask=input_dict["attention_mask"], do_sample=True, max_new_tokens=10
)
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
config.audio_channels = 2
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
decoder_input_ids=decoder_input_ids,
max_length=max_length,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
self.assertNotIn(config.pad_token_id, output_generate)
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
"""Produces a series of 'bip bip' sounds at a given frequency."""
@ -1357,3 +1401,79 @@ class MusicgenIntegrationTests(unittest.TestCase):
output_values.shape == (2, 1, 36480)
) # input values take shape 32000 and we generate from there
self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, atol=1e-4))
@require_torch
class MusicgenStereoIntegrationTests(unittest.TestCase):
@cached_property
def model(self):
return MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-stereo-small").to(torch_device)
@cached_property
def processor(self):
return MusicgenProcessor.from_pretrained("facebook/musicgen-stereo-small")
@slow
def test_generate_unconditional_greedy(self):
model = self.model
# only generate 1 sample with greedy - since it's deterministic all elements of the batch will be the same
unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
unconditional_inputs = place_dict_on_device(unconditional_inputs, device=torch_device)
output_values = model.generate(**unconditional_inputs, do_sample=False, max_new_tokens=12)
# fmt: off
EXPECTED_VALUES_LEFT = torch.tensor(
[
0.0017, 0.0004, 0.0004, 0.0005, 0.0002, 0.0002, -0.0002, -0.0013,
-0.0010, -0.0015, -0.0018, -0.0032, -0.0060, -0.0082, -0.0096, -0.0099,
]
)
EXPECTED_VALUES_RIGHT = torch.tensor(
[
0.0038, 0.0028, 0.0031, 0.0032, 0.0031, 0.0032, 0.0030, 0.0019,
0.0021, 0.0015, 0.0009, -0.0008, -0.0040, -0.0067, -0.0087, -0.0096,
]
)
# fmt: on
# (bsz, channels, seq_len)
self.assertTrue(output_values.shape == (1, 2, 5760))
self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES_LEFT, atol=1e-4))
self.assertTrue(torch.allclose(output_values[0, 1, :16].cpu(), EXPECTED_VALUES_RIGHT, atol=1e-4))
@slow
def test_generate_text_audio_prompt(self):
model = self.model
processor = self.processor
# create stereo inputs
audio = [get_bip_bip(duration=0.5)[None, :].repeat(2, 0), get_bip_bip(duration=1.0)[None, :].repeat(2, 0)]
text = ["80s music", "Club techno"]
inputs = processor(audio=audio, text=text, padding=True, return_tensors="pt")
inputs = place_dict_on_device(inputs, device=torch_device)
output_values = model.generate(**inputs, do_sample=False, guidance_scale=3.0, max_new_tokens=12)
# fmt: off
EXPECTED_VALUES_LEFT = torch.tensor(
[
0.2535, 0.2008, 0.1471, 0.0896, 0.0306, -0.0200, -0.0501, -0.0728,
-0.0832, -0.0856, -0.0867, -0.0884, -0.0864, -0.0866, -0.0744, -0.0430,
]
)
EXPECTED_VALUES_RIGHT = torch.tensor(
[
0.1695, 0.1213, 0.0732, 0.0239, -0.0264, -0.0705, -0.0935, -0.1103,
-0.1163, -0.1139, -0.1104, -0.1082, -0.1027, -0.1004, -0.0900, -0.0614,
]
)
# fmt: on
# (bsz, channels, seq_len)
self.assertTrue(output_values.shape == (2, 2, 37760))
# input values take shape 32000 and we generate from there - we check the last (generated) values
self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, atol=1e-4))
self.assertTrue(torch.allclose(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, atol=1e-4))