Update Bark generation configs and tests (#25409)

* update bark generation configs for more coherent parameters

* make style

* update bark hub repo
Yoach Lacombe 2023-08-09 18:28:02 +02:00 committed by GitHub
parent cf84738d2e
commit 704bf595eb
4 changed files with 39 additions and 20 deletions


@@ -36,8 +36,8 @@ class BarkSemanticGenerationConfig(GenerationConfig):
         return_dict_in_generate=False,
         output_hidden_states=False,
         output_attentions=False,
-        temperature=0.7,
-        do_sample=True,
+        temperature=1.0,
+        do_sample=False,
         text_encoding_offset=10_048,
         text_pad_token=129_595,
         semantic_infer_token=129_599,
@@ -70,9 +70,9 @@ class BarkSemanticGenerationConfig(GenerationConfig):
             output_attentions (`bool`, *optional*, defaults to `False`):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more details.
-            temperature (`float`, *optional*, defaults to 0.7):
+            temperature (`float`, *optional*, defaults to 1.0):
                 The value used to modulate the next token probabilities.
-            do_sample (`bool`, *optional*, defaults to `True`):
+            do_sample (`bool`, *optional*, defaults to `False`):
                 Whether or not to use sampling ; use greedy decoding otherwise.
             text_encoding_offset (`int`, *optional*, defaults to 10_048):
                 Text encoding offset.
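Note how the two flags interact: with the new defaults (`do_sample=False`, `temperature=1.0`), decoding is greedy and the temperature has no effect; temperature only matters once sampling is enabled. A minimal standalone sketch of these semantics (illustrative helper, not Bark's actual decoding loop):

```python
import torch


def next_token(logits: torch.Tensor, do_sample: bool, temperature: float) -> torch.Tensor:
    """Illustrative sketch of `do_sample`/`temperature` semantics."""
    if not do_sample:
        # Greedy decoding: temperature is ignored, take the most likely token.
        return logits.argmax(dim=-1)
    # Sampling: rescale logits before the softmax. temperature < 1.0 sharpens
    # the distribution, > 1.0 flattens it, and 1.0 leaves it unchanged.
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)


logits = torch.randn(1, 1_000)  # vocabulary size is illustrative
print(next_token(logits, do_sample=False, temperature=1.0))
```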
@@ -119,8 +119,8 @@ class BarkCoarseGenerationConfig(GenerationConfig):
         return_dict_in_generate=False,
         output_hidden_states=False,
         output_attentions=False,
-        temperature=0.7,
-        do_sample=True,
+        temperature=1.0,
+        do_sample=False,
         coarse_semantic_pad_token=12_048,
         coarse_rate_hz=75,
         n_coarse_codebooks=2,
@@ -150,9 +150,9 @@ class BarkCoarseGenerationConfig(GenerationConfig):
             output_attentions (`bool`, *optional*, defaults to `False`):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more details.
-            temperature (`float`, *optional*, defaults to 0.7):
+            temperature (`float`, *optional*, defaults to 1.0):
                 The value used to modulate the next token probabilities.
-            do_sample (`bool`, *optional*, defaults to `True`):
+            do_sample (`bool`, *optional*, defaults to `False`):
                 Whether or not to use sampling ; use greedy decoding otherwise.
             coarse_semantic_pad_token (`int`, *optional*, defaults to 12_048):
                 Coarse semantic pad token.
@@ -194,7 +194,7 @@ class BarkFineGenerationConfig(GenerationConfig):

     def __init__(
         self,
-        temperature=0.5,
+        temperature=1.0,
         max_fine_history_length=512,
         max_fine_input_length=1024,
         n_fine_codebooks=8,
@@ -209,7 +209,7 @@ class BarkFineGenerationConfig(GenerationConfig):
         documentation from [`GenerationConfig`] for more information.

         Args:
-            temperature (`float`, *optional*, defaults to 0.5):
+            temperature (`float`, *optional*):
                 The value used to modulate the next token probabilities.
             max_fine_history_length (`int`, *optional*, defaults to 512):
                 Max length of the fine history vector.
@@ -224,6 +224,13 @@ class BarkFineGenerationConfig(GenerationConfig):
         self.max_fine_input_length = max_fine_input_length
         self.n_fine_codebooks = n_fine_codebooks

+    def validate(self):
+        """
+        Overrides GenerationConfig.validate because BarkFineGenerationConfig don't use any parameters outside
+        temperature.
+        """
+        pass
+

 class BarkGenerationConfig(GenerationConfig):
     model_type = "bark"
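The `validate` override above exists because the base `GenerationConfig.validate` flags parameters that are set but unused under the current sampling mode, while the fine sub-model only ever reads `temperature`. A self-contained sketch of the pattern with stand-in classes (illustrative only, not the transformers implementation):

```python
import warnings


class BaseConfig:
    """Stand-in for GenerationConfig (illustrative only)."""

    def __init__(self, do_sample=False, temperature=1.0):
        self.do_sample = do_sample
        self.temperature = temperature

    def validate(self):
        # The real GenerationConfig warns when sampling-only flags are set
        # while do_sample=False.
        if not self.do_sample and self.temperature != 1.0:
            warnings.warn("`temperature` is set but `do_sample=False`, so it is ignored.")


class FineConfig(BaseConfig):
    def validate(self):
        # The fine sub-model reads only `temperature`, so the base checks do
        # not apply; skipping them mirrors the override in the diff above.
        pass


FineConfig(do_sample=False, temperature=0.5).validate()  # no warning raised
```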


@@ -1336,7 +1336,7 @@ class BarkFineModel(BarkPreTrainedModel):
                 input_buffer = fine_input[:, start_idx : start_idx + max_fine_input_length, :]
                 for n_inner in range(n_coarse, fine_generation_config.n_fine_codebooks):
                     logits = self.forward(n_inner, input_buffer).logits
-                    if temperature is None:
+                    if temperature is None or temperature == 1.0:
                         relevant_logits = logits[:, rel_start_fill_idx:, :codebook_size]
                         codebook_preds = torch.argmax(relevant_logits, -1)
                     else:
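The widened condition makes the fine model interpret the new default `temperature=1.0` as "no sampling", the same as `temperature is None`, so it takes the deterministic argmax path instead of sampling. A standalone sketch of the two branches with illustrative shapes (not the Bark source):

```python
import torch

batch, n_positions, codebook_size = 2, 5, 1024
logits = torch.randn(batch, n_positions, codebook_size)  # illustrative shapes
temperature = 1.0

if temperature is None or temperature == 1.0:
    # Deterministic branch: most likely codebook entry per position.
    codebook_preds = torch.argmax(logits, dim=-1)
else:
    # Sampling branch: rescale logits, then draw one entry per position.
    probs = torch.softmax(logits / temperature, dim=-1)
    codebook_preds = torch.multinomial(
        probs.reshape(-1, codebook_size), num_samples=1
    ).reshape(batch, n_positions)

print(codebook_preds.shape)  # torch.Size([2, 5])
```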
@@ -1499,8 +1499,8 @@ class BarkModel(BarkPreTrainedModel):
         ```python
         >>> from transformers import AutoProcessor, BarkModel

-        >>> processor = AutoProcessor.from_pretrained("ylacombe/bark-small")
-        >>> model = BarkModel.from_pretrained("ylacombe/bark-small")
+        >>> processor = AutoProcessor.from_pretrained("suno/bark-small")
+        >>> model = BarkModel.from_pretrained("suno/bark-small")

         >>> # To add a voice preset, you can pass `voice_preset` to `BarkProcessor.__call__(...)`
         >>> voice_preset = "v2/en_speaker_6"
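For reference, the surrounding docstring example continues roughly as follows; a hedged sketch of standard Bark usage with the renamed checkpoint (the `generate` call and `sample_rate` attribute follow the Bark documentation):

```python
from transformers import AutoProcessor, BarkModel

processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained("suno/bark-small")

# A voice preset conditions generation on a given speaker.
voice_preset = "v2/en_speaker_6"
inputs = processor("Hello, my dog is cute", voice_preset=voice_preset)

# With the new defaults this decodes greedily; pass do_sample=True to sample.
audio_array = model.generate(**inputs)
audio_array = audio_array.cpu().numpy().squeeze()

sample_rate = model.generation_config.sample_rate  # 24 kHz for Bark
```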


@@ -894,11 +894,11 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
 class BarkModelIntegrationTests(unittest.TestCase):
     @cached_property
     def model(self):
-        return BarkModel.from_pretrained("ylacombe/bark-large").to(torch_device)
+        return BarkModel.from_pretrained("suno/bark").to(torch_device)

     @cached_property
     def processor(self):
-        return BarkProcessor.from_pretrained("ylacombe/bark-large")
+        return BarkProcessor.from_pretrained("suno/bark")

     @cached_property
     def inputs(self):
@@ -937,6 +937,7 @@ class BarkModelIntegrationTests(unittest.TestCase):
             output_ids = self.model.semantic.generate(
                 **input_ids,
                 do_sample=False,
+                temperature=1.0,
                 semantic_generation_config=self.semantic_generation_config,
             )

@@ -957,6 +958,7 @@ class BarkModelIntegrationTests(unittest.TestCase):
             output_ids = self.model.semantic.generate(
                 **input_ids,
                 do_sample=False,
+                temperature=1.0,
                 semantic_generation_config=self.semantic_generation_config,
            )

@@ -964,6 +966,7 @@ class BarkModelIntegrationTests(unittest.TestCase):
                 output_ids,
                 history_prompt=history_prompt,
                 do_sample=False,
+                temperature=1.0,
                 semantic_generation_config=self.semantic_generation_config,
                 coarse_generation_config=self.coarse_generation_config,
                 codebook_size=self.model.generation_config.codebook_size,
@@ -994,6 +997,7 @@ class BarkModelIntegrationTests(unittest.TestCase):
             output_ids = self.model.semantic.generate(
                 **input_ids,
                 do_sample=False,
+                temperature=1.0,
                 semantic_generation_config=self.semantic_generation_config,
             )

@@ -1001,6 +1005,7 @@ class BarkModelIntegrationTests(unittest.TestCase):
                 output_ids,
                 history_prompt=history_prompt,
                 do_sample=False,
+                temperature=1.0,
                 semantic_generation_config=self.semantic_generation_config,
                 coarse_generation_config=self.coarse_generation_config,
                 codebook_size=self.model.generation_config.codebook_size,
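Context for these hunks: the integration tests drive Bark's staged pipeline by hand, first generating semantic tokens, then coarse acoustic codebooks, now with `temperature=1.0` made explicit at each stage. A condensed, hypothetical mirror of that flow (hedged: `model`, `input_ids`, `history_prompt`, and the two sub-configs are assumed to be prepared exactly as in the test class):

```python
import torch

# Hypothetical condensed mirror of the integration-test flow above; the
# setup objects are assumed to exist as in BarkModelIntegrationTests.
with torch.no_grad():
    # Stage 1: text tokens -> semantic tokens (greedy, explicit temperature).
    semantic_ids = model.semantic.generate(
        **input_ids,
        do_sample=False,
        temperature=1.0,
        semantic_generation_config=semantic_generation_config,
    )
    # Stage 2: semantic tokens -> coarse acoustic codebooks.
    coarse_ids = model.coarse_acoustics.generate(
        semantic_ids,
        history_prompt=history_prompt,
        do_sample=False,
        temperature=1.0,
        semantic_generation_config=semantic_generation_config,
        coarse_generation_config=coarse_generation_config,
        codebook_size=model.generation_config.codebook_size,
    )
```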
@@ -1040,9 +1045,16 @@ class BarkModelIntegrationTests(unittest.TestCase):
         input_ids = self.inputs

         with torch.no_grad():
-            self.model.generate(**input_ids, do_sample=False, coarse_do_sample=True, coarse_temperature=0.7)
             self.model.generate(
-                **input_ids, do_sample=False, coarse_do_sample=True, coarse_temperature=0.7, fine_temperature=0.3
+                **input_ids, do_sample=False, temperature=1.0, coarse_do_sample=True, coarse_temperature=0.7
             )
+            self.model.generate(
+                **input_ids,
+                do_sample=False,
+                temperature=1.0,
+                coarse_do_sample=True,
+                coarse_temperature=0.7,
+                fine_temperature=0.3,
+            )
             self.model.generate(
                 **input_ids,
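These calls rely on `BarkModel.generate` splitting keyword arguments by prefix: unprefixed kwargs (here `do_sample` and `temperature`) are forwarded to every sub-model, while `semantic_*`, `coarse_*`, and `fine_*` kwargs apply only to the matching sub-model. A hedged sketch of that call pattern:

```python
from transformers import AutoProcessor, BarkModel

processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained("suno/bark-small")
inputs = processor("This is a test string")

audio_array = model.generate(
    **inputs,
    do_sample=False,         # unprefixed: applies to every sub-model
    temperature=1.0,         # unprefixed: kept explicit, as in the tests
    coarse_do_sample=True,   # overrides only the coarse sub-model...
    coarse_temperature=0.7,  # ...which then samples at temperature 0.7
    fine_temperature=0.3,    # temperature for the fine sub-model
)
```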
@@ -1061,7 +1073,7 @@ class BarkModelIntegrationTests(unittest.TestCase):

         with torch.no_grad():
             # standard generation
-            output_with_no_offload = self.model.generate(**input_ids, do_sample=False, fine_temperature=None)
+            output_with_no_offload = self.model.generate(**input_ids, do_sample=False, temperature=1.0)

             torch.cuda.empty_cache()

@@ -1088,7 +1100,7 @@ class BarkModelIntegrationTests(unittest.TestCase):
             self.assertTrue(hasattr(self.model.semantic, "_hf_hook"))

             # output with cpu offload
-            output_with_offload = self.model.generate(**input_ids, do_sample=False, fine_temperature=None)
+            output_with_offload = self.model.generate(**input_ids, do_sample=False, temperature=1.0)

             # checks if same output
             self.assertListEqual(output_with_no_offload.tolist(), output_with_offload.tolist())
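The offload assertions exercise `BarkModel.enable_cpu_offload`, which hooks each sub-model via accelerate (hence the `_hf_hook` check) so it sits on the CPU and moves to the GPU only for its own step of the pipeline; the test then verifies that offloaded and non-offloaded generation produce identical output. A hedged usage sketch (assumes a CUDA device and accelerate installed):

```python
import torch
from transformers import AutoProcessor, BarkModel

processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained("suno/bark-small")

# Keep idle sub-models on the CPU; each is moved to the GPU only while
# it is generating, trading some speed for a smaller GPU memory footprint.
model.enable_cpu_offload()

inputs = processor("This is a test string")
with torch.no_grad():
    audio_array = model.generate(**inputs, do_sample=False, temperature=1.0)
```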


@@ -26,7 +26,7 @@ from transformers.testing_utils import require_torch, slow
 @require_torch
 class BarkProcessorTest(unittest.TestCase):
     def setUp(self):
-        self.checkpoint = "ylacombe/bark-small"
+        self.checkpoint = "suno/bark-small"
         self.tmpdirname = tempfile.mkdtemp()
         self.voice_preset = "en_speaker_1"
         self.input_string = "This is a test string"