Make it possible for bark to have a tiny model (#25290)

* temp

* update

* update

* update

* small dim

* small dim

* small dim

* fix

* update

* fix

* fix

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yih-Dar 2023-08-04 15:13:14 +02:00 committed by GitHub
parent f0fd73a2de
commit ce6d153a53
3 changed files with 95 additions and 26 deletions

src/transformers/models/bark/configuration_bark.py

@@ -19,7 +19,7 @@ from typing import Dict, Optional, Union
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import add_start_docstrings, logging
-from ..auto import AutoConfig
+from ..auto import CONFIG_MAPPING
 
 
 logger = logging.get_logger(__name__)
@@ -299,7 +299,8 @@ class BarkConfig(PretrainedConfig):
 
         self.semantic_config = BarkSemanticConfig(**semantic_config)
         self.coarse_acoustics_config = BarkCoarseConfig(**coarse_acoustics_config)
         self.fine_acoustics_config = BarkFineConfig(**fine_acoustics_config)
-        self.codec_config = AutoConfig.for_model(**codec_config)
+        codec_model_type = codec_config["model_type"] if "model_type" in codec_config else "encodec"
+        self.codec_config = CONFIG_MAPPING[codec_model_type](**codec_config)
 
         self.initializer_range = initializer_range
@@ -311,7 +312,7 @@ class BarkConfig(PretrainedConfig):
         semantic_config: BarkSemanticConfig,
         coarse_acoustics_config: BarkCoarseConfig,
         fine_acoustics_config: BarkFineConfig,
-        codec_config: AutoConfig,
+        codec_config: PretrainedConfig,
         **kwargs,
     ):
         r"""

tests/models/bark/test_modeling_bark.py

@@ -22,6 +22,7 @@ import unittest
 
 from transformers import (
     BarkCoarseConfig,
+    BarkConfig,
     BarkFineConfig,
     BarkSemanticConfig,
     is_torch_available,
@@ -37,6 +38,7 @@ from transformers.utils import cached_property
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ..encodec.test_modeling_encodec import EncodecModelTester
 
 
 if is_torch_available():
@@ -72,8 +74,6 @@ class BarkSemanticModelTester:
         initializer_range=0.02,
         n_codes_total=8,  # for BarkFineModel
         n_codes_given=1,  # for BarkFineModel
-        config_class=None,
-        model_class=None,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -98,8 +98,6 @@ class BarkSemanticModelTester:
         self.n_codes_given = n_codes_given
 
         self.is_encoder_decoder = False
-        self.config_class = config_class
-        self.model_class = model_class
 
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -121,7 +119,7 @@ class BarkSemanticModelTester:
         return config, inputs_dict
 
     def get_config(self):
-        return self.config_class(
+        return BarkSemanticConfig(
             vocab_size=self.vocab_size,
             output_vocab_size=self.output_vocab_size,
             hidden_size=self.hidden_size,
@@ -137,6 +135,7 @@ class BarkSemanticModelTester:
     def get_pipeline_config(self):
         config = self.get_config()
         config.vocab_size = 300
+        config.output_vocab_size = 300
         return config
 
     def prepare_config_and_inputs_for_common(self):
@@ -144,7 +143,7 @@ class BarkSemanticModelTester:
         return config, inputs_dict
 
     def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
-        model = self.model_class(config=config).to(torch_device).eval()
+        model = BarkSemanticModel(config=config).to(torch_device).eval()
         input_ids = inputs_dict["input_ids"]
         attention_mask = inputs_dict["attention_mask"]
@@ -211,8 +210,6 @@ class BarkCoarseModelTester:
         initializer_range=0.02,
         n_codes_total=8,  # for BarkFineModel
         n_codes_given=1,  # for BarkFineModel
-        config_class=None,
-        model_class=None,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -237,8 +234,6 @@ class BarkCoarseModelTester:
         self.n_codes_given = n_codes_given
 
         self.is_encoder_decoder = False
-        self.config_class = config_class
-        self.model_class = model_class
 
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -260,7 +255,7 @@ class BarkCoarseModelTester:
         return config, inputs_dict
 
     def get_config(self):
-        return self.config_class(
+        return BarkCoarseConfig(
             vocab_size=self.vocab_size,
             output_vocab_size=self.output_vocab_size,
             hidden_size=self.hidden_size,
@@ -276,6 +271,7 @@ class BarkCoarseModelTester:
     def get_pipeline_config(self):
         config = self.get_config()
         config.vocab_size = 300
+        config.output_vocab_size = 300
         return config
 
     def prepare_config_and_inputs_for_common(self):
@@ -283,7 +279,7 @@ class BarkCoarseModelTester:
         return config, inputs_dict
 
     def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
-        model = self.model_class(config=config).to(torch_device).eval()
+        model = BarkCoarseModel(config=config).to(torch_device).eval()
         input_ids = inputs_dict["input_ids"]
         attention_mask = inputs_dict["attention_mask"]
@@ -350,8 +346,6 @@ class BarkFineModelTester:
         initializer_range=0.02,
         n_codes_total=8,  # for BarkFineModel
         n_codes_given=1,  # for BarkFineModel
-        config_class=None,
-        model_class=None,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -376,8 +370,6 @@ class BarkFineModelTester:
         self.n_codes_given = n_codes_given
 
         self.is_encoder_decoder = False
-        self.config_class = config_class
-        self.model_class = model_class
 
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length, self.n_codes_total], self.vocab_size)
@@ -403,7 +395,7 @@ class BarkFineModelTester:
         return config, inputs_dict
 
     def get_config(self):
-        return self.config_class(
+        return BarkFineConfig(
             vocab_size=self.vocab_size,
             output_vocab_size=self.output_vocab_size,
             hidden_size=self.hidden_size,
@@ -419,6 +411,7 @@ class BarkFineModelTester:
     def get_pipeline_config(self):
         config = self.get_config()
         config.vocab_size = 300
+        config.output_vocab_size = 300
         return config
 
     def prepare_config_and_inputs_for_common(self):
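
A plausible reading of why all three get_pipeline_config methods now shrink output_vocab_size together with vocab_size (this rationale is mine, not stated in the diff): Bark sub-models keep separate input and output vocabularies, so shrinking only the input side would leave the output head at its full default size. A sketch, with assumed defaults:

    # Sketch: both vocabularies must shrink for a genuinely tiny pipeline model.
    from transformers import BarkSemanticConfig

    config = BarkSemanticConfig()
    config.vocab_size = 300         # shrinks the input embedding table
    config.output_vocab_size = 300  # without this, the output head keeps its full default size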
@@ -426,7 +419,7 @@ class BarkFineModelTester:
         return config, inputs_dict
 
     def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
-        model = self.model_class(config=config).to(torch_device).eval()
+        model = BarkFineModel(config=config).to(torch_device).eval()
         input_ids = inputs_dict["input_ids"]
         attention_mask = inputs_dict["attention_mask"]
@@ -473,6 +466,79 @@ class BarkFineModelTester:
         self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
 
 
+class BarkModelTester:
+    def __init__(
+        self,
+        parent,
+        semantic_kwargs=None,
+        coarse_acoustics_kwargs=None,
+        fine_acoustics_kwargs=None,
+        codec_kwargs=None,
+        is_training=False,  # for now training is not supported
+    ):
+        if semantic_kwargs is None:
+            semantic_kwargs = {}
+        if coarse_acoustics_kwargs is None:
+            coarse_acoustics_kwargs = {}
+        if fine_acoustics_kwargs is None:
+            fine_acoustics_kwargs = {}
+        if codec_kwargs is None:
+            codec_kwargs = {}
+
+        self.parent = parent
+        self.semantic_model_tester = BarkSemanticModelTester(parent, **semantic_kwargs)
+        self.coarse_acoustics_model_tester = BarkCoarseModelTester(parent, **coarse_acoustics_kwargs)
+        self.fine_acoustics_model_tester = BarkFineModelTester(parent, **fine_acoustics_kwargs)
+        self.codec_model_tester = EncodecModelTester(parent, **codec_kwargs)
+
+        self.is_training = is_training
+
+    def prepare_config_and_inputs(self):
+        # TODO: @Yoach: Prepare `inputs_dict`
+        inputs_dict = {}
+        config = self.get_config()
+        return config, inputs_dict
+
+    def get_config(self):
+        return BarkConfig.from_sub_model_configs(
+            self.semantic_model_tester.get_config(),
+            self.coarse_acoustics_model_tester.get_config(),
+            self.fine_acoustics_model_tester.get_config(),
+            self.codec_model_tester.get_config(),
+        )
+
+    def get_pipeline_config(self):
+        config = self.get_config()
+
+        # follow the `get_pipeline_config` of the sub-component models
+        config.semantic_config.vocab_size = 300
+        config.coarse_acoustics_config.vocab_size = 300
+        config.fine_acoustics_config.vocab_size = 300
+
+        config.semantic_config.output_vocab_size = 300
+        config.coarse_acoustics_config.output_vocab_size = 300
+        config.fine_acoustics_config.output_vocab_size = 300
+
+        return config
+
+    def prepare_config_and_inputs_for_common(self):
+        # TODO: @Yoach
+        pass
+        # return config, inputs_dict
+
+
+# Need this class in order to create a tiny model for `bark`
+# TODO (@Yoach) Implement actual test methods
+@unittest.skip("So far all tests will fail.")
+class BarkModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+    all_model_classes = (BarkModel,) if is_torch_available() else ()
+
+    def setUp(self):
+        self.model_tester = BarkModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=BarkConfig, n_embd=37)
+
+
 @require_torch
 class BarkSemanticModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (BarkSemanticModel,) if is_torch_available() else ()
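
To show what the new BarkModelTester composes, here is a hypothetical usage sketch of building a tiny Bark model from the four sub-configs (the small dimensions are illustrative assumptions, not the values the testers use):

    # Hypothetical sketch: a tiny BarkConfig composed the way BarkModelTester.get_config does.
    from transformers import (
        BarkCoarseConfig,
        BarkConfig,
        BarkFineConfig,
        BarkModel,
        BarkSemanticConfig,
        EncodecConfig,
    )

    tiny_config = BarkConfig.from_sub_model_configs(
        BarkSemanticConfig(hidden_size=16, num_layers=2, num_heads=2),
        BarkCoarseConfig(hidden_size=16, num_layers=2, num_heads=2),
        BarkFineConfig(hidden_size=16, num_layers=2, num_heads=2),
        EncodecConfig(),  # accepted as a plain PretrainedConfig; a default-size Encodec here, the tester uses a tiny one
    )
    tiny_model = BarkModel(tiny_config)  # randomly initialized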
@@ -488,9 +554,7 @@ class BarkSemanticModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     test_resize_embeddings = True
 
     def setUp(self):
-        self.model_tester = BarkSemanticModelTester(
-            self, config_class=BarkSemanticConfig, model_class=BarkSemanticModel
-        )
+        self.model_tester = BarkSemanticModelTester(self)
         self.config_tester = ConfigTester(self, config_class=BarkSemanticConfig, n_embd=37)
 
     def test_config(self):
@@ -556,7 +620,7 @@ class BarkCoarseModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     test_resize_embeddings = True
 
     def setUp(self):
-        self.model_tester = BarkCoarseModelTester(self, config_class=BarkCoarseConfig, model_class=BarkCoarseModel)
+        self.model_tester = BarkCoarseModelTester(self)
         self.config_tester = ConfigTester(self, config_class=BarkCoarseConfig, n_embd=37)
 
     def test_config(self):
@@ -623,7 +687,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
     test_resize_embeddings = True
 
    def setUp(self):
-        self.model_tester = BarkFineModelTester(self, config_class=BarkFineConfig, model_class=BarkFineModel)
+        self.model_tester = BarkFineModelTester(self)
         self.config_tester = ConfigTester(self, config_class=BarkFineConfig, n_embd=37)
 
     def test_config(self):

utils/create_dummy_models.py

@@ -974,6 +974,10 @@ def get_token_id_from_tokenizer(token_id_name, tokenizer, original_token_id):
 
 
 def get_config_overrides(config_class, processors):
+    # `Bark` configuration is too special. Let's just not handle this for now.
+    if config_class.__name__ == "BarkConfig":
+        return {}
+
     config_overrides = {}
 
     # Check if there is any tokenizer (prefer fast version if any)
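
The early exit added above fits the shape of the function: as I read it, get_config_overrides produces a flat attribute-to-value mapping, while Bark's vocabulary sizes live one level down inside nested sub-configs, and BarkModelTester.get_pipeline_config performs the equivalent shrinking instead. A sketch of the mismatch (the dict shapes are my illustration, not from the diff):

    # A flat override describes most configs...
    flat_override = {"vocab_size": 300}

    # ...but Bark's equivalent is nested, one entry per sub-config, which is why
    # BarkModelTester.get_pipeline_config mutates the sub-configs directly.
    bark_equivalent = {
        "semantic_config": {"vocab_size": 300, "output_vocab_size": 300},
        "coarse_acoustics_config": {"vocab_size": 300, "output_vocab_size": 300},
        "fine_acoustics_config": {"vocab_size": 300, "output_vocab_size": 300},
    }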