transformers/tests/models/musicgen/test_processing_musicgen.py
Sanchit Gandhi 1c1c90756d
Add Musicgen (#24109)
* Add Audiocraft

* add cross attention

* style

* add for lm

* convert and verify

* introduce t5

* split configs

* load t5 + lm

* clean conversion

* copy from t5

* style

* start pattern provider

* make generation work

* style

* fix pos embs

* propagate shape changes

* propagate shape changes

* style

* delay pattern: pad tokens at end

* audiocraft -> musicgen

* fix inits

* add mdx

* style

* fix pad token in processor

* override generate and add todos

* add init to test

* undo pattern delay mask after gen

* remove cfg logits processor

* remove cfg logits processor

* remove logits processor in favour of mask

* clean pos embs

* make fix copies

* update readmes

* clean pos emb

* refactor encoder/decoder

* make fix copies

* update conversion

* fix config imports

* update config docs

* make style

* send pattern mask to device

* pattern mask with delay

* recover prompted audio tokens

* fix docstrings

* laydown test file

* pattern edge case

* remove t5 ref

* add processing class

* config refactor

* better pattern comment

* check if mask is not present

* check if mask is not present

* refactor to auto class

* remove encoder configs

* fix processor

* processor import

* start updating conversion

* start updating tests

* make style

* convert t5, encodec, lm

* convert as composite

* also convert processor

* run generate

* classifier free gen

* comments and clean up

* make style

* docs for logit proc

* docstring for uncond gen

* start lm tests

* work tests

* let the lm generate

* refactor: reshape inside forward

* undo greedy loop changes

* from_enc_dec -> from_sub_model

* fix input id shapes in docstrings

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* undo generate changes

* from sub model config

* Update src/transformers/models/musicgen/modeling_musicgen.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make generate work again

* generate uncond -> get uncond inputs

* remove prefix allowed tokens fn

* better error message

* logit proc checks

* Apply suggestions from code review

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* make decoder only tests work

* composite fast tests

* make style

* uncond generation

* feat extr padding

* make audio prompt work

* fix inputs docstrings

* unconditional inputs: dict -> model output

* clean up tests

* more clean up tests

* make style

* t5 encoder -> auto text encoder

* remove comments

* deal with frames

* fix auto text

* slow tests

* nice mdx

* remove can generate

* todo - hub id

* convert m/l

* make fix copies

* only import generation with torch

* ignore decoder from tests

* don't wrap uncond inputs

* make style

* cleaner uncond inputs

* add example to musicgen forward

* fix docs

* ignore MusicGen Model/ForConditionalGeneration in auto mapping

* add doc section to toctree

* add to doc tests

* add processor tests

* fix push to hub in conversion

* tips for decoder only loading

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix conversion for s / m / l checkpoints

* import stopping criteria from module

* remove from pipeline tests

* fix uncond docstring

* decode audio method

* fix docs

* org: sanchit-gandhi -> facebook

* fix max pos embeddings

* remove auto doc (not compatible with shapes)

* bump max pos emb

* make style

* fix doc

* fix config doc

* fix config doc

* ignore musicgen config from docstring

* make style

* fix config

* fix config for doctest

* consistent from_sub_models

* don't automap decoder

* fix mdx save audio file

* fix mdx save audio file

* processor batch decode for audio

* remove keys to ignore

* update doc md

* update generation config

* allow changes for default generation config

* update tests

* make style

* fix docstring for uncond

* fix processor test

* fix processor test

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2023-06-29 14:48:59 +01:00

174 lines
6.4 KiB
Python

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the MusicGen processor."""
import random
import shutil
import tempfile
import unittest
import numpy as np
from transformers import T5Tokenizer, T5TokenizerFast
from transformers.testing_utils import require_sentencepiece, require_torch
from transformers.utils.import_utils import is_speech_available, is_torch_available
if is_torch_available():
pass
if is_speech_available():
from transformers import EncodecFeatureExtractor, MusicgenProcessor
# Shared fallback generator used by `floats_list` when the caller passes no `rng`.
# It is constructed without an explicit seed, so generated values are not reproducible across runs.
global_rng = random.Random()
def floats_list(shape, scale=1.0, rng=None, name=None):
    """Create a nested list of random Python floats.

    Args:
        shape: Indexable ``(rows, cols)``; ``shape[0]`` is the number of inner
            lists and ``shape[1]`` the number of values in each of them.
        scale: Factor every value drawn from ``rng.random()`` is multiplied by.
        rng: Object exposing a ``random()`` method. When ``None``, the
            module-level ``global_rng`` is used.
        name: Unused; kept so callers that pass it keep working.

    Returns:
        A list of ``shape[0]`` lists, each containing ``shape[1]`` floats.
        Values are drawn row by row, so a seeded ``rng`` gives a reproducible
        result.
    """
    if rng is None:
        rng = global_rng
    # `shape[1]` is only evaluated while a row is being built, so a zero-row
    # request never touches it.
    return [[rng.random() * scale for _ in range(shape[1])] for _ in range(shape[0])]
@require_torch
@require_sentencepiece
@unittest.skipUnless(
    is_speech_available(),
    "test requires the speech dependencies that provide EncodecFeatureExtractor and MusicgenProcessor",
)
class MusicgenProcessorTest(unittest.TestCase):
    """Tests for `MusicgenProcessor`, built from a T5 tokenizer and an EnCodec feature extractor.

    `EncodecFeatureExtractor` and `MusicgenProcessor` are only imported by this module when
    `is_speech_available()` is true, so the whole class is skipped otherwise instead of
    failing with a `NameError`.
    """

    def setUp(self):
        """Record the checkpoint to load from and create a scratch directory for save/load tests."""
        self.checkpoint = "facebook/musicgen-small"
        self.tmpdirname = tempfile.mkdtemp()

    def get_tokenizer(self, **kwargs):
        """Load the slow T5 tokenizer from the checkpoint, forwarding `kwargs` to `from_pretrained`."""
        return T5Tokenizer.from_pretrained(self.checkpoint, **kwargs)

    def get_feature_extractor(self, **kwargs):
        """Load the EnCodec feature extractor from the checkpoint, forwarding `kwargs` to `from_pretrained`."""
        return EncodecFeatureExtractor.from_pretrained(self.checkpoint, **kwargs)

    def tearDown(self):
        """Remove the scratch directory created in `setUp`."""
        shutil.rmtree(self.tmpdirname)

    def test_save_load_pretrained_default(self):
        """A processor saved to disk and reloaded keeps its tokenizer vocab and feature-extractor config."""
        tokenizer = self.get_tokenizer()
        feature_extractor = self.get_feature_extractor()

        processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        processor.save_pretrained(self.tmpdirname)
        processor = MusicgenProcessor.from_pretrained(self.tmpdirname)

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
        # The reloaded tokenizer is expected to be the fast variant even though a slow one was saved.
        self.assertIsInstance(processor.tokenizer, T5TokenizerFast)

        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
        self.assertIsInstance(processor.feature_extractor, EncodecFeatureExtractor)

    def test_save_load_pretrained_additional_features(self):
        """Extra kwargs given to `MusicgenProcessor.from_pretrained` match loading each component with them."""
        processor = MusicgenProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
        processor.save_pretrained(self.tmpdirname)

        # Reference components loaded directly with the same overrides passed to the processor below.
        tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
        feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)

        processor = MusicgenProcessor.from_pretrained(
            self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
        )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, T5TokenizerFast)

        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.feature_extractor, EncodecFeatureExtractor)

    def test_feature_extractor(self):
        """Calling the processor on raw audio yields the same features as the feature extractor alone."""
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        raw_speech = floats_list((3, 1000))

        input_feat_extract = feature_extractor(raw_speech, return_tensors="np")
        input_processor = processor(raw_speech, return_tensors="np")

        # Outputs are compared through their sums, with a small tolerance.
        for key in input_feat_extract.keys():
            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

    def test_tokenizer(self):
        """Calling the processor with `text=` yields the same encoding as the tokenizer alone."""
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        input_str = "This is a test string"

        encoded_processor = processor(text=input_str)
        encoded_tok = tokenizer(input_str)

        for key in encoded_tok.keys():
            self.assertListEqual(encoded_tok[key], encoded_processor[key])

    def test_tokenizer_decode(self):
        """`processor.batch_decode(sequences=...)` returns the same strings as `tokenizer.batch_decode`."""
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]

        decoded_processor = processor.batch_decode(sequences=predicted_ids)
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)

    def test_model_input_names(self):
        """The processor reports the same model input names as its feature extractor."""
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        self.assertListEqual(
            processor.model_input_names,
            feature_extractor.model_input_names,
            msg="`processor` and `feature_extractor` model input names do not match",
        )

    def test_decode_audio(self):
        """`batch_decode` with a padding mask returns one trimmed numpy array per generated audio."""
        feature_extractor = self.get_feature_extractor(padding_side="left")
        tokenizer = self.get_tokenizer()

        processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        # Three prompts of lengths 5, 10 and 15 samples.
        raw_speech = [floats_list((1, x))[0] for x in range(5, 20, 5)]
        padding_mask = processor(raw_speech).padding_mask

        # Fake generated audio of shape (batch=3, channels=1, length=20).
        generated_speech = np.asarray(floats_list((3, 20)))[:, None, :]
        decoded_audios = processor.batch_decode(generated_speech, padding_mask=padding_mask)

        self.assertIsInstance(decoded_audios, list)

        for audio in decoded_audios:
            self.assertIsInstance(audio, np.ndarray)

        # NOTE(review): the expected lengths presumably come from stripping only the padding
        # that was added to each prompt - confirm against the processor's audio decoding.
        # `assertEqual` is used so a failure reports the actual shape.
        self.assertEqual(decoded_audios[0].shape, (1, 10))
        self.assertEqual(decoded_audios[1].shape, (1, 15))
        self.assertEqual(decoded_audios[2].shape, (1, 20))