transformers/docs/source/en/model_doc/musicgen.md
Sanchit Gandhi 1c1c90756d
Add Musicgen (#24109)
2023-06-29 14:48:59 +01:00


MusicGen

Overview

The MusicGen model was proposed in the paper Simple and Controllable Music Generation by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.

MusicGen is a single-stage autoregressive Transformer model capable of generating high-quality music samples conditioned on text descriptions or audio prompts. The text descriptions are passed through a frozen text encoder model to obtain a sequence of hidden-state representations. MusicGen is then trained to predict discrete audio tokens, or audio codes, conditioned on these hidden states. These audio tokens are then decoded using an audio compression model, such as EnCodec, to recover the audio waveform.

Through an efficient token interleaving pattern, MusicGen does not require a self-supervised semantic representation of the text/audio prompts, thus eliminating the need to cascade multiple models to predict a set of codebooks (e.g. hierarchically or via upsampling). Instead, it is able to generate all the codebooks in a single forward pass.
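One way to picture the interleaving is the "delay" pattern described in the paper: codebook k is shifted right by k steps, so each decoding step emits one token for every codebook, and T frames of K codebooks take T + K - 1 steps. The following is a minimal pure-Python sketch under stated assumptions: codes are given as K rows of T tokens, and -1 is a hypothetical placeholder for "no token at this step". It illustrates the idea only; the transformers implementation works on tensors and uses its own pad token.

```python
from typing import List

PAD = -1  # hypothetical placeholder id meaning "no token at this step"


def apply_delay_pattern(codes: List[List[int]], pad: int = PAD) -> List[List[int]]:
    """Shift codebook k right by k steps (the 'delay' interleaving)."""
    num_codebooks = len(codes)
    num_frames = len(codes[0])
    total_steps = num_frames + num_codebooks - 1
    delayed = []
    for k, row in enumerate(codes):
        # k leading pads, the original codes, then trailing pads to a common length
        delayed.append([pad] * k + list(row) + [pad] * (total_steps - k - num_frames))
    return delayed


def undo_delay_pattern(delayed: List[List[int]], num_frames: int) -> List[List[int]]:
    """Invert apply_delay_pattern by dropping each codebook's k leading steps."""
    return [row[k : k + num_frames] for k, row in enumerate(delayed)]


codes = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]  # 3 codebooks x 3 frames
print(apply_delay_pattern(codes))
# [[1, 2, 3, -1, -1], [-1, 4, 5, 6, -1], [-1, -1, 7, 8, 9]]
```

Reading the delayed grid column by column shows the trade-off: at step t the model predicts codebook 0 of frame t alongside codebook 1 of frame t-1, so later codebooks can condition on earlier codebooks of the same frame without a separate model.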

The abstract from the paper is the following:

We tackle the task of conditional music generation. We introduce MusicGen, a single Language Model (LM) that operates over several streams of compressed discrete music representation, i.e., tokens. Unlike prior work, MusicGen is comprised of a single-stage transformer LM together with efficient token interleaving patterns, which eliminates the need for cascading several models, e.g., hierarchically or upsampling. Following this approach, we demonstrate how MusicGen can generate high-quality samples, while being conditioned on textual description or melodic features, allowing better controls over the generated output. We conduct extensive empirical evaluation, considering both automatic and human studies, showing the proposed approach is superior to the evaluated baselines on a standard text-to-music benchmark. Through ablation studies, we shed light over the importance of each of the components comprising MusicGen.

This model was contributed by sanchit-gandhi. The original code can be found here. The pre-trained checkpoints can be found on the Hugging Face Hub.

Generation

MusicGen is compatible with two generation modes: greedy and sampling. In practice, sampling leads to significantly better results than greedy, so we encourage using sampling mode where possible. Sampling is enabled by default, and can be explicitly specified by setting do_sample=True in the call to [MusicgenForConditionalGeneration.generate], or by overriding the model's generation config (see below).
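To make the difference between the two modes concrete, here is a toy, pure-Python sketch of a single decoding step over a made-up logits vector. The helper names are hypothetical and this is not the transformers generation code, which applies its own logits processing before choosing a token.

```python
import math
import random
from typing import List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_pick(logits: List[float]) -> int:
    # greedy: always take the single most likely token
    return max(range(len(logits)), key=lambda i: logits[i])


def sample_pick(logits: List[float], rng: random.Random) -> int:
    # sampling: draw a token at random, weighted by its probability
    probs = softmax(logits)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [2.0, 1.5, 0.1]
rng = random.Random(0)
print(greedy_pick(logits))                            # 0
print([sample_pick(logits, rng) for _ in range(10)])  # a mix of token ids
```

Greedy decoding is deterministic and always returns the same token for the same logits, whereas sampling can also pick less likely tokens, which is what lets it explore varied continuations.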

Unconditional Generation

The inputs for unconditional (or 'null') generation can be obtained through the method [MusicgenForConditionalGeneration.get_unconditional_inputs]:

>>> from transformers import MusicgenForConditionalGeneration

>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
>>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)

>>> audio_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=256)
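max_new_tokens counts decoder steps, and each step corresponds to one frame of audio codes, so it controls the length of the generated audio. A rough conversion is sketched below, assuming a frame rate of 50 Hz for the 32 kHz EnCodec model (32,000 samples per second divided by 640 samples per frame); check model.config.audio_encoder for the actual values. The helpers are hypothetical, and the few extra steps introduced by the codebook interleaving are ignored.

```python
import math


def tokens_to_seconds(num_tokens: int, frame_rate: float = 50.0) -> float:
    """Approximate audio duration for a number of generated frames.

    frame_rate=50.0 is an assumption (32 kHz EnCodec, 640 samples per frame).
    """
    return num_tokens / frame_rate


def seconds_to_tokens(seconds: float, frame_rate: float = 50.0) -> int:
    # round up so the requested duration is fully covered
    return math.ceil(seconds * frame_rate)


print(tokens_to_seconds(256))  # 5.12
print(seconds_to_tokens(10))   # 500
```

Under these assumptions, max_new_tokens=256 yields roughly 5 seconds of audio.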

The audio outputs are a three-dimensional PyTorch tensor of shape (batch_size, num_channels, sequence_length). To listen to the generated audio samples, you can either play them in a Jupyter (ipynb) notebook:

from IPython.display import Audio

sampling_rate = model.config.audio_encoder.sampling_rate
Audio(audio_values[0].numpy(), rate=sampling_rate)

Or save them as a .wav file using a third-party library, e.g. scipy:

>>> import scipy.io.wavfile

>>> sampling_rate = model.config.audio_encoder.sampling_rate
>>> scipy.io.wavfile.write("musicgen_out.wav", rate=sampling_rate, data=audio_values[0, 0].numpy())
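If scipy is not available, Python's built-in wave module can write a playable file as well. Below is a minimal sketch, assuming a mono waveform of floats in [-1.0, 1.0] such as audio_values[0, 0].numpy(); the function name is hypothetical. It converts the samples to 16-bit PCM, a widely supported WAV encoding, clipping anything outside the valid range.

```python
import struct
import wave
from typing import Sequence


def write_wav_pcm16(path: str, samples: Sequence[float], sampling_rate: int) -> None:
    """Write a mono float waveform in [-1.0, 1.0] as a 16-bit PCM WAV file."""
    frames = bytearray()
    for x in samples:
        x = max(-1.0, min(1.0, float(x)))  # clip to the valid range
        frames += struct.pack("<h", int(round(x * 32767)))  # little-endian int16
    with wave.open(path, "wb") as f:
        f.setnchannels(1)  # mono, matching audio_values[0, 0]
        f.setsampwidth(2)  # 2 bytes = 16 bits per sample
        f.setframerate(sampling_rate)
        f.writeframes(bytes(frames))


# usage (assuming the objects from the example above):
# write_wav_pcm16("musicgen_out.wav", audio_values[0, 0].numpy(), sampling_rate)
```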

Text-Conditional Generation

The model can generate an audio sample conditioned on a text prompt through use of the [MusicgenProcessor] to pre-process the inputs:

>>> from transformers import AutoProcessor, MusicgenForConditionalGeneration

>>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

>>> inputs = processor(
...     text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
...     padding=True,
...     return_tensors="pt",
... )
>>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)

The guidance_scale is used in classifier-free guidance (CFG), setting the weighting between the conditional logits (which are predicted from the text prompts) and the unconditional logits (which are predicted from an unconditional or 'null' prompt). A higher guidance scale encourages the model to generate samples that are more closely linked to the input prompt, usually at the expense of audio quality. CFG is enabled by setting guidance_scale > 1. For best results, use guidance_scale=3 (the default).
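In the standard CFG formulation, the two sets of logits are combined at each step as uncond + guidance_scale * (cond - uncond), so guidance_scale=1 reduces to the conditional logits alone and larger values push further away from the unconditional prediction. The sketch below applies this formula to toy Python lists; it is an illustration of the formula, not the transformers logits processor, and the function name is hypothetical.

```python
from typing import List


def apply_cfg(cond_logits: List[float], uncond_logits: List[float], guidance_scale: float) -> List[float]:
    """Classifier-free guidance: uncond + scale * (cond - uncond), element-wise."""
    return [u + guidance_scale * (c - u) for c, u in zip(cond_logits, uncond_logits)]


cond = [2.0, 0.0]    # logits predicted from the text prompt
uncond = [1.0, 1.0]  # logits predicted from the 'null' prompt
print(apply_cfg(cond, uncond, 1.0))  # [2.0, 0.0]: identical to the conditional logits
print(apply_cfg(cond, uncond, 3.0))  # [4.0, -2.0]: the text-favoured token is amplified
```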

Audio-Prompted Generation

The same [MusicgenProcessor] can be used to pre-process an audio prompt that is used for audio continuation. In the following example, we load an audio file using the 🤗 Datasets library, which can be pip installed through the command below:

pip install --upgrade pip
pip install "datasets[audio]"

>>> from transformers import AutoProcessor, MusicgenForConditionalGeneration
>>> from datasets import load_dataset

>>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

>>> dataset = load_dataset("sanchit-gandhi/gtzan", split="train", streaming=True)
>>> sample = next(iter(dataset))["audio"]

>>> # take the first half of the audio sample
>>> sample["array"] = sample["array"][: len(sample["array"]) // 2]

>>> inputs = processor(
...     audio=sample["array"],
...     sampling_rate=sample["sampling_rate"],
...     text=["80s blues track with groovy saxophone"],
...     padding=True,
...     return_tensors="pt",
... )
>>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)
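When an audio prompt is supplied, the prompt is encoded into audio-code frames and generation continues from there. Assuming the returned waveform contains the encoded prompt followed by max_new_tokens newly generated frames, and assuming a 50 Hz frame rate for the 32 kHz EnCodec model, its length can be estimated as in the hypothetical helper below. This is a rough estimate under those assumptions, not a guarantee about the output length.

```python
import math


def output_duration_seconds(prompt_samples: int, sampling_rate: int, max_new_tokens: int,
                            frame_rate: float = 50.0) -> float:
    """Rough length of the returned audio, ASSUMING it is prompt + continuation.

    frame_rate=50.0 is an assumption for the 32 kHz EnCodec model.
    """
    # number of code frames needed to represent the prompt
    prompt_frames = math.ceil(prompt_samples / sampling_rate * frame_rate)
    return (prompt_frames + max_new_tokens) / frame_rate


# a 10 s prompt at 32 kHz plus 256 new frames
print(output_duration_seconds(32000 * 10, 32000, 256))  # about 15.12
```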

For batched audio-prompted generation, the generated audio_values can be post-processed to remove padding by using the [MusicgenProcessor] class:

>>> from transformers import AutoProcessor, MusicgenForConditionalGeneration
>>> from datasets import load_dataset

>>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

>>> dataset = load_dataset("sanchit-gandhi/gtzan", split="train", streaming=True)
>>> sample = next(iter(dataset))["audio"]

>>> # take the first quarter of the audio sample
>>> sample_1 = sample["array"][: len(sample["array"]) // 4]

>>> # take the first half of the audio sample
>>> sample_2 = sample["array"][: len(sample["array"]) // 2]

>>> inputs = processor(
...     audio=[sample_1, sample_2],
...     sampling_rate=sample["sampling_rate"],
...     text=["80s blues track with groovy saxophone", "90s rock song with loud guitars and heavy drums"],
...     padding=True,
...     return_tensors="pt",
... )
>>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)

>>> # post-process to remove padding from the batched audio
>>> audio_values = processor.batch_decode(audio_values, padding_mask=inputs.padding_mask)
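Conceptually, this post-processing drops the samples that came from padding the shorter prompt, so the cleaned outputs can differ in length and are naturally held in a list. A pure-Python sketch of that idea on toy lists is shown below, assuming a mask value of 1 marks a real sample and 0 marks padding, and that the generated continuation beyond the end of the mask is never padding. The helper is hypothetical and is not the MusicgenProcessor implementation.

```python
from typing import List


def strip_padding(batch: List[List[float]], padding_mask: List[List[int]]) -> List[List[float]]:
    """Remove padded positions from each waveform in a batch."""
    cleaned = []
    for audio, mask in zip(batch, padding_mask):
        # extend the mask with 1s to cover the generated continuation
        full_mask = list(mask) + [1] * (len(audio) - len(mask))
        cleaned.append([x for x, keep in zip(audio, full_mask) if keep == 1])
    return cleaned


batch = [
    [0.1, 0.2, 0.0, 0.0, 0.5, 0.6],  # short prompt: two padded samples, then continuation
    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],  # long prompt: no padding
]
mask = [[1, 1, 0, 0], [1, 1, 1, 1]]
print(strip_padding(batch, mask))
# [[0.1, 0.2, 0.5, 0.6], [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]]
```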

Generation Configuration

The default parameters that control the generation process, such as sampling, guidance scale and number of generated tokens, can be found in the model's generation config, and updated as desired:

>>> from transformers import MusicgenForConditionalGeneration

>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

>>> # inspect the default generation config
>>> model.generation_config

>>> # increase the guidance scale to 4.0
>>> model.generation_config.guidance_scale = 4.0

>>> # decrease the max length to 256 tokens
>>> model.generation_config.max_length = 256

Note that any arguments passed to the generate method will supersede those in the generation config. For example, setting do_sample=False in the call to generate overrides the value of model.generation_config.do_sample.
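The precedence rule can be pictured with a toy sketch: a hypothetical frozen dataclass stands in for a few fields of the generation config, and dataclasses.replace merges per-call keyword arguments on top of it. This models only the "call arguments win" behaviour described above, not how generate actually merges its settings.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyGenerationConfig:
    # hypothetical stand-in for a few generation-config fields
    do_sample: bool
    guidance_scale: float
    max_length: int


def resolve_generation_settings(config: ToyGenerationConfig, **overrides) -> ToyGenerationConfig:
    """Per-call keyword arguments take precedence over the stored config."""
    # replace() returns a new object and raises TypeError for unknown field names
    return replace(config, **overrides)


stored = ToyGenerationConfig(do_sample=True, guidance_scale=4.0, max_length=256)
effective = resolve_generation_settings(stored, do_sample=False)
print(effective)  # do_sample=False, guidance_scale=4.0, max_length=256
print(stored)     # unchanged: do_sample=True
```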

Model Structure

The MusicGen model can be decomposed into three distinct stages:

  1. Text encoder: maps the text inputs to a sequence of hidden-state representations. The pre-trained MusicGen models use a frozen text encoder from either T5 or Flan-T5
  2. MusicGen decoder: a language model (LM) that auto-regressively generates audio tokens (or codes) conditional on the encoder hidden-state representations
  3. Audio encoder/decoder: used to encode an audio prompt into prompt tokens, and to recover the audio waveform from the audio tokens predicted by the decoder

Thus, the MusicGen model can either be used as a standalone decoder model, corresponding to the class [MusicgenForCausalLM], or as a composite model that includes the text encoder and audio encoder/decoder, corresponding to the class [MusicgenForConditionalGeneration].

Since the text encoder and audio encoder/decoder models are frozen during training, the MusicGen decoder [MusicgenForCausalLM] can be trained standalone on a dataset of encoder hidden states and audio codes. For inference, the trained decoder can be combined with the frozen text encoder and audio encoder/decoder to recover the composite [MusicgenForConditionalGeneration] model.
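To make the data flow between the three stages concrete, here is a shape-bookkeeping sketch. The defaults are assumptions: 768 is the hidden size of t5-base, and 4 codebooks with 640 samples per frame correspond to the 32 kHz EnCodec model. The helper is hypothetical, and the real tensors in transformers may carry extra dimensions or a different layout.

```python
from typing import Dict, Tuple


def musicgen_shapes(batch_size: int, text_len: int, num_frames: int,
                    hidden_size: int = 768, num_codebooks: int = 4,
                    hop_length: int = 640) -> Dict[str, Tuple[int, ...]]:
    """Illustrative shapes through the three stages (default values are assumptions)."""
    return {
        # 1. text encoder: token ids -> hidden-state representations
        "text_input_ids": (batch_size, text_len),
        "encoder_hidden_states": (batch_size, text_len, hidden_size),
        # 2. decoder: one audio code per codebook per frame
        "audio_codes": (batch_size, num_codebooks, num_frames),
        # 3. audio decoder: codes -> mono waveform, hop_length samples per frame
        "audio_values": (batch_size, 1, num_frames * hop_length),
    }


for name, shape in musicgen_shapes(batch_size=2, text_len=12, num_frames=256).items():
    print(name, shape)
```

Only the middle stage is trained, which is why the decoder can be fitted on precomputed (encoder_hidden_states, audio_codes) pairs.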

Below, we demonstrate how to construct the composite [MusicgenForConditionalGeneration] model from its three constituent parts, as would typically be done following training of the MusicGen decoder LM:

>>> from transformers import AutoConfig, AutoModelForTextEncoding, AutoModel, MusicgenForCausalLM, MusicgenForConditionalGeneration

>>> text_encoder = AutoModelForTextEncoding.from_pretrained("t5-base")
>>> audio_encoder = AutoModel.from_pretrained("facebook/encodec_32khz")
>>> decoder_config = AutoConfig.from_pretrained("facebook/musicgen-small").decoder
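>>> # pass the decoder sub-config through the `config` argument
>>> decoder = MusicgenForCausalLM.from_pretrained("facebook/musicgen-small", config=decoder_config)

>>> # build the composite model (and its config) from the three sub-models
>>> model = MusicgenForConditionalGeneration(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder)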

If only the decoder needs to be loaded from the pre-trained checkpoint for the composite model, it can be loaded by first specifying the correct config, or be accessed through the .decoder attribute of the composite model:

>>> from transformers import AutoConfig, MusicgenForCausalLM, MusicgenForConditionalGeneration

>>> # Option 1: get decoder config and pass to `.from_pretrained`
>>> decoder_config = AutoConfig.from_pretrained("facebook/musicgen-small").decoder
>>> decoder = MusicgenForCausalLM.from_pretrained("facebook/musicgen-small", config=decoder_config)

>>> # Option 2: load the entire composite model, but only return the decoder
>>> decoder = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small").decoder

Tips:

  • MusicGen is trained on the 32 kHz checkpoint of EnCodec. Make sure you use a compatible version of the EnCodec model.
  • Sampling mode tends to deliver better results than greedy decoding; you can toggle sampling with the do_sample argument in the call to [MusicgenForConditionalGeneration.generate]

MusicgenDecoderConfig

autodoc MusicgenDecoderConfig

MusicgenConfig

autodoc MusicgenConfig

MusicgenProcessor

autodoc MusicgenProcessor

MusicgenModel

autodoc MusicgenModel - forward

MusicgenForCausalLM

autodoc MusicgenForCausalLM - forward

MusicgenForConditionalGeneration

autodoc MusicgenForConditionalGeneration - forward