mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-19 20:48:22 +06:00

* first raw version of the bark integration * working code on small models with single run * add converting script from suno weights 2 hf * many changes * correct past_kv output * working implementation for inference * update the converting script according to the architecture changes * add a working end-to-end inference code * remove some comments and make small changes * remove unecessary comment * add docstrings and ensure no unecessary intermediary output during audio generation * remove done TODOs * make style + add config docstrings * modification for batch inference support on the whole model * add details to .generation_audio method * add copyright * convert EncodecModel from original library to transformers implementation * add two class in order to facilitate model and sub-models loading from the hub * add support of loading the whole model * add BarkProcessor * correct modeling according to processor output * Add proper __init__ and auto support * Add up-to-date copyright/license message * add relative import instead of absolute * cleaner head_dim computation * small comment removal or changes * more verbose LayerNorm init method * specify eps for clearer comprehension * more verbose variable naming in the MLP module * remove unecessary BarkBlock parameter * clearer code in the forward pass of the BarkBlock * remove _initialize_modules method for cleaner code * Remove unnecessary methods from sub-models * move code to remove unnecessary function * rename a variable for clarity and change an assert * move code and change variable name for clarity * remove unnecessary asserts * correct small bug * correct a comment * change variable names for clarity * remove asserts * change import from absolute to relative * correct small error due to comma missing + correct import * Add attribute Bark config * add first version of tests * update attention_map * add tie_weights and resize_token_embeddings for fineModel * correct getting attention_mask in 
generate_text_semantic * remove Bark inference trick * leave more choices in barkProcessor * remove _no_split_modules * fixe error in forward of block and introduce clearer notations * correct converting script with last changes * make style + add draft bark.mdx * correct BarkModelTest::test_generate_text_semantic * add Bark in main README * add dummy_pt_objects for Bark * add missing models in the main init * correct test_decoder_model_past_with_large_inputs * disable torchscript test * change docstring of BarkProcessor * Add test_processor_bark * make style * correct copyrights * add bark.mdx + make style, quality and consistency * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Remove unnecessary test method * simply logic of a test * Only check first ids for slow audio generation * split full end-to-end generation tests * remove unneccessary comment * change submodel names for clearer naming * remove ModuleDict from modeling_bark * combine two if statements * ensure that an edge misued won't happen * modify variable name * move code snippet to the right place (coarse instead of semantic) * change BarkSemanticModule -> BarkSemanticModel * align BarkProcessor with transformers paradigm * correct BarkProcessor tests with last commit changes * change _validate_voice_preset to an instance method instead of a class method * tie_weights already called with post_init * add codec_model config to configuration * update bark modeling tests with recent BarkProcessor changes * remove SubModelPretrainedModel + change speakers embeddings prompt type in BarkModel * change absolute imports to relative * remove TODO * change docstrings * add examples to docs and docstrings * make style * uses BatchFeature in BarkProcessor insteads of dict * continue improving docstrings and docs + make style * correct docstrings examples * more comprehensible speaker_embeddings load/Save * rename speaker_embeddings_dict -> 
speaker_embeddings * correct bark.mdx + add bark to documentation_tests * correct docstrings configuration_bark * integrate last nit suggestions * integrate BarkGeneration configs * make style * remove bark tests from documentation_tests.txt because timeout - tested manually * add proper generation config initialization * small bark.mdx documentation changes * rename bark.mdx -> bark.md * add torch.no_grad behind BarkModel.generate_audio() * replace assert by ValueError in convert_suno_to_hf.py * integrate a series of short comments from reviewer * move SemanticLogitsProcessors and remove .detach() from Bark docs and docstrings * actually remove SemanticLogitsProcessor from modeling_bark.oy * BarkProcessor returns a single output instead of tuple + correct docstrings * make style + correct bug * add initializer_range to BarkConfig + correct slow modeling tests * add .clone() to history_prompt.coarse_prompt to avoid modifying input array * Making sure no extra "`" are present * remove extra characters in modeling_bark.py * Correct output if history_prompt is None * remove TODOs * remove ravel comment * completing generation_configuration_bark.py docstrings * change docstrings - number of audio codebooks instead of Encodec codebooks * change 'bias' docstrings in configuration_bark.py * format code * rename BarkModel.generate_audio -> BarkModel.generate_speech * modify AutoConfig instead of EncodecConfig in BarkConfig * correct AutoConfig wrong init * refactor BarkModel and sub-models generate_coarse, generate_fine, generate_text_semantic * remove SemanticLogitsProcessor and replace it with SuppressTokensLogitsProcessor * move nb_codebook related config arguments to BarkFineConfig * rename bark.mdx -> bark.md * correcting BarkModelConfig from_pretrained + remove keys_to_ignore * correct bark.md with correct hub path * correct code bug in bark.md * correct list tokens_to_suppress * modify Processor to load nested speaker embeddings in a safer way * correct batch 
sampling in BarkFineModel.generate_fine * Apply suggestions from code review Small docstrings correction and code improvements Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * give more details about num_layers in docstrings * correct indentation mistake * correct submodelconfig order of docstring variables * put audio models in alphabetical order in utils/check_repo.my * remove useless line from test_modeling_bark.py * makes BarkCoarseModelTest inherits from (ModelTesterMixin, GenerationTesterMixin, unittest.TestCase) instead of BarkSemanticModelTest * make a Tester class for each sub-model instead of inheriting * add test_resize_embeddings=True for Bark sub-models * add Copied from transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoSelfAttention._split_heads * remove 'Copied fom Bark' comment * remove unneccessary comment * change np.min -> min in modeling_bark.py * refactored all custom layers to have Bark prefix * add attention_mask as an argument of generate_text_semantic * refactor sub-models start docstrings to have more precise config class definition * move _tied_weights_keys overriding * add docstrings to generate_xxx in modeling_bark.py * add loading whole BarkModel to convert_suno_to_hf * refactor attribute and variable names * make style convert_suno * update bark checkpoints * remove never entered if statement * move bark_modeling docstrings after BarkPretrainedModel class definition * refactor modeling_bark.py: kv -> key_values * small nits - code refactoring and removing unecessary lines from _init_weights * nits - replace inplace method by variable assigning * remove *optional* when necessary * remove some lines in generate_speech * add default value for optional parameter * Refactor preprocess_histories_before_coarse -> preprocess_histories Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * correct usage after refactoring * refactor Bark's generate_xxx -> generate and modify docstrings and 
tests accordingly * update docstrings python in configuration_bark.py * add bark files in utils/documentation_test.txt * correct docstrings python snippet * add the ability to use parameters in the form of e.g coarse_temperature * add semantic_max_new_tokens in python snippet in docstrings for quicker generation * Reformate sub-models kwargs in BakModel.generate Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * correct kwargs in BarkModel.generate * correct attention_mask kwarg in BarkModel.generate * add tests for sub-models args in BarkModel.generate and correct BarkFineModel.test_generate_fp16 * enrich BarkModel.generate docstrings with a description of how to use the kwargs --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
992 lines
38 KiB
Python
# coding=utf-8
|
|
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" Testing suite for the PyTorch Bark model. """
|
|
|
|
|
|
import copy
|
|
import inspect
|
|
import tempfile
|
|
import unittest
|
|
|
|
from transformers import (
|
|
BarkCoarseConfig,
|
|
BarkFineConfig,
|
|
BarkSemanticConfig,
|
|
is_torch_available,
|
|
)
|
|
from transformers.models.bark.generation_configuration_bark import (
|
|
BarkCoarseGenerationConfig,
|
|
BarkFineGenerationConfig,
|
|
BarkSemanticGenerationConfig,
|
|
)
|
|
from transformers.testing_utils import require_torch, slow, torch_device
|
|
from transformers.utils import cached_property
|
|
|
|
from ...generation.test_utils import GenerationTesterMixin
|
|
from ...test_configuration_common import ConfigTester
|
|
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
from transformers import (
|
|
BarkCausalModel,
|
|
BarkCoarseModel,
|
|
BarkFineModel,
|
|
BarkModel,
|
|
BarkProcessor,
|
|
BarkSemanticModel,
|
|
)
|
|
|
|
|
|
class BarkSemanticModelTester:
    """Helper that builds a small config and random inputs for the Bark semantic sub-model tests.

    The owning ``unittest.TestCase`` is passed as ``parent`` and is used for assertions.
    ``config_class`` and ``model_class`` are injected by the test case so the helper itself
    does not hard-code which Bark sub-model it exercises.
    """

    def __init__(
        self,
        parent,
        batch_size=2,
        seq_length=4,
        is_training=False,  # for now training is not supported
        use_input_mask=True,
        use_labels=True,
        vocab_size=33,
        output_vocab_size=33,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=15,
        dropout=0.1,
        window_size=256,
        initializer_range=0.02,
        n_codes_total=8,  # for BarkFineModel
        n_codes_given=1,  # for BarkFineModel
        config_class=None,
        model_class=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.output_vocab_size = output_vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.window_size = window_size
        self.initializer_range = initializer_range
        # bos, eos and pad all share the last id of the output vocabulary
        self.bos_token_id = output_vocab_size - 1
        self.eos_token_id = output_vocab_size - 1
        self.pad_token_id = output_vocab_size - 1

        self.n_codes_total = n_codes_total
        self.n_codes_given = n_codes_given

        self.is_encoder_decoder = False
        self.config_class = config_class
        self.model_class = model_class

    def prepare_config_and_inputs(self):
        """Return ``(config, inputs_dict)`` with ``input_ids``, ``head_mask`` and ``attention_mask``.

        ``attention_mask`` is ``None`` when ``use_input_mask`` is false.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        config = self.get_config()

        # one entry per (layer, head), drawn by ids_tensor with a vocabulary of size 2
        head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)

        inputs_dict = {
            "input_ids": input_ids,
            "head_mask": head_mask,
            "attention_mask": input_mask,
        }

        return config, inputs_dict

    def get_config(self):
        """Instantiate ``config_class`` from the tester's hyper-parameters."""
        return self.config_class(
            vocab_size=self.vocab_size,
            output_vocab_size=self.output_vocab_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_hidden_layers,
            num_heads=self.num_attention_heads,
            use_cache=True,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            window_size=self.window_size,
        )

    def get_pipeline_config(self):
        """Return the regular config with ``vocab_size`` overridden to 300."""
        config = self.get_config()
        config.vocab_size = 300
        return config

    def prepare_config_and_inputs_for_common(self):
        """Return the same ``(config, inputs_dict)`` pair as ``prepare_config_and_inputs``."""
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
        """Check that logits computed with cached key/values match a full forward pass.

        The comparison is done twice: once with an attention mask and once without.
        Only one randomly chosen logit column of the three appended positions is compared.
        """
        model = self.model_class(config=config).to(torch_device).eval()

        input_ids = inputs_dict["input_ids"]
        attention_mask = inputs_dict["attention_mask"]

        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

        _, past_key_values = outputs.to_tuple()

        # create hypothetical multiple next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_attn_mask = ids_tensor((self.batch_size, 3), 2)

        # append the new tokens and their mask to the original inputs
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["logits"]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
            "logits"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertEqual(output_from_past_slice.shape[1], next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

        # test no attention_mask works
        outputs = model(input_ids, use_cache=True)
        _, past_key_values = outputs.to_tuple()
        output_from_no_past = model(next_input_ids)["logits"]

        output_from_past = model(next_tokens, past_key_values=past_key_values)["logits"]

        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
|
|
|
|
|
class BarkCoarseModelTester:
    """Helper that builds a small config and random inputs for the Bark coarse sub-model tests.

    The owning ``unittest.TestCase`` is passed as ``parent`` and is used for assertions.
    ``config_class`` and ``model_class`` are injected by the test case.
    """

    def __init__(
        self,
        parent,
        batch_size=2,
        seq_length=4,
        is_training=False,  # for now training is not supported
        use_input_mask=True,
        use_labels=True,
        vocab_size=33,
        output_vocab_size=33,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=15,
        dropout=0.1,
        window_size=256,
        initializer_range=0.02,
        n_codes_total=8,  # for BarkFineModel
        n_codes_given=1,  # for BarkFineModel
        config_class=None,
        model_class=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.output_vocab_size = output_vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.window_size = window_size
        self.initializer_range = initializer_range
        # bos, eos and pad all share the last id of the output vocabulary
        self.bos_token_id = output_vocab_size - 1
        self.eos_token_id = output_vocab_size - 1
        self.pad_token_id = output_vocab_size - 1

        self.n_codes_total = n_codes_total
        self.n_codes_given = n_codes_given

        self.is_encoder_decoder = False
        self.config_class = config_class
        self.model_class = model_class

    def prepare_config_and_inputs(self):
        """Return ``(config, inputs_dict)`` with ``input_ids``, ``head_mask`` and ``attention_mask``.

        ``attention_mask`` is ``None`` when ``use_input_mask`` is false.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        config = self.get_config()

        # one entry per (layer, head), drawn by ids_tensor with a vocabulary of size 2
        head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)

        inputs_dict = {
            "input_ids": input_ids,
            "head_mask": head_mask,
            "attention_mask": input_mask,
        }

        return config, inputs_dict

    def get_config(self):
        """Instantiate ``config_class`` from the tester's hyper-parameters."""
        return self.config_class(
            vocab_size=self.vocab_size,
            output_vocab_size=self.output_vocab_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_hidden_layers,
            num_heads=self.num_attention_heads,
            use_cache=True,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            window_size=self.window_size,
        )

    def get_pipeline_config(self):
        """Return the regular config with ``vocab_size`` overridden to 300."""
        config = self.get_config()
        config.vocab_size = 300
        return config

    def prepare_config_and_inputs_for_common(self):
        """Return the same ``(config, inputs_dict)`` pair as ``prepare_config_and_inputs``."""
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
        """Check that logits computed with cached key/values match a full forward pass.

        The comparison is done twice: once with an attention mask and once without.
        Only one randomly chosen logit column of the three appended positions is compared.
        """
        model = self.model_class(config=config).to(torch_device).eval()

        input_ids = inputs_dict["input_ids"]
        attention_mask = inputs_dict["attention_mask"]

        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

        _, past_key_values = outputs.to_tuple()

        # create hypothetical multiple next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_attn_mask = ids_tensor((self.batch_size, 3), 2)

        # append the new tokens and their mask to the original inputs
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["logits"]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
            "logits"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertEqual(output_from_past_slice.shape[1], next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

        # test no attention_mask works
        outputs = model(input_ids, use_cache=True)
        _, past_key_values = outputs.to_tuple()
        output_from_no_past = model(next_input_ids)["logits"]

        output_from_past = model(next_tokens, past_key_values=past_key_values)["logits"]

        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
|
|
|
|
|
class BarkFineModelTester:
    """Helper that builds a small config and random inputs for the Bark fine sub-model tests.

    Unlike the semantic/coarse testers, ``input_ids`` carries an extra trailing axis of size
    ``n_codes_total`` and the inputs include a ``codebook_idx`` entry.
    ``config_class`` and ``model_class`` are injected by the test case.
    """

    def __init__(
        self,
        parent,
        batch_size=2,
        seq_length=4,
        is_training=False,  # for now training is not supported
        use_input_mask=True,
        use_labels=True,
        vocab_size=33,
        output_vocab_size=33,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=15,
        dropout=0.1,
        window_size=256,
        initializer_range=0.02,
        n_codes_total=8,  # for BarkFineModel
        n_codes_given=1,  # for BarkFineModel
        config_class=None,
        model_class=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.output_vocab_size = output_vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.window_size = window_size
        self.initializer_range = initializer_range
        # bos, eos and pad all share the last id of the output vocabulary
        self.bos_token_id = output_vocab_size - 1
        self.eos_token_id = output_vocab_size - 1
        self.pad_token_id = output_vocab_size - 1

        self.n_codes_total = n_codes_total
        self.n_codes_given = n_codes_given

        self.is_encoder_decoder = False
        self.config_class = config_class
        self.model_class = model_class

    def prepare_config_and_inputs(self):
        """Return ``(config, inputs_dict)`` for the fine model.

        ``inputs_dict`` holds ``codebook_idx``, ``input_ids`` (with a trailing axis of size
        ``n_codes_total``), ``head_mask`` and ``attention_mask`` (``None`` when
        ``use_input_mask`` is false).
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length, self.n_codes_total], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        config = self.get_config()

        # one entry per (layer, head), drawn by ids_tensor with a vocabulary of size 2
        head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)

        # random int between self.n_codes_given and self.n_codes_total - 1 (both inclusive):
        # the draw covers (n_codes_total - n_codes_given) values and is shifted by n_codes_given
        codebook_idx = ids_tensor((1,), self.n_codes_total - self.n_codes_given).item() + self.n_codes_given

        inputs_dict = {
            "codebook_idx": codebook_idx,
            "input_ids": input_ids,
            "head_mask": head_mask,
            "attention_mask": input_mask,
        }

        return config, inputs_dict

    def get_config(self):
        """Instantiate ``config_class`` from the tester's hyper-parameters.

        The codebook counts are forwarded so the config agrees with the trailing axis of the
        ``input_ids`` built in ``prepare_config_and_inputs`` when non-default values are used.
        """
        return self.config_class(
            vocab_size=self.vocab_size,
            output_vocab_size=self.output_vocab_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_hidden_layers,
            num_heads=self.num_attention_heads,
            use_cache=True,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            window_size=self.window_size,
            n_codes_total=self.n_codes_total,
            n_codes_given=self.n_codes_given,
        )

    def get_pipeline_config(self):
        """Return the regular config with ``vocab_size`` overridden to 300."""
        config = self.get_config()
        config.vocab_size = 300
        return config

    def prepare_config_and_inputs_for_common(self):
        """Return the same ``(config, inputs_dict)`` pair as ``prepare_config_and_inputs``."""
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
        """Check that logits computed with cached key/values match a full forward pass.

        NOTE(review): this helper has the same body as the semantic/coarse testers. The
        ``input_ids`` built by this tester have three axes while ``next_tokens`` below has two,
        so the ``torch.cat`` looks incompatible with fine-model inputs, and ``codebook_idx`` is
        never passed to the model. Presumably this method is unused for the fine model --
        confirm before relying on it.
        """
        model = self.model_class(config=config).to(torch_device).eval()

        input_ids = inputs_dict["input_ids"]
        attention_mask = inputs_dict["attention_mask"]

        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

        _, past_key_values = outputs.to_tuple()

        # create hypothetical multiple next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_attn_mask = ids_tensor((self.batch_size, 3), 2)

        # append the new tokens and their mask to the original inputs
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["logits"]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
            "logits"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertEqual(output_from_past_slice.shape[1], next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

        # test no attention_mask works
        outputs = model(input_ids, use_cache=True)
        _, past_key_values = outputs.to_tuple()
        output_from_no_past = model(next_input_ids)["logits"]

        output_from_past = model(next_tokens, past_key_values=past_key_values)["logits"]

        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
|
|
|
|
|
@require_torch
class BarkSemanticModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    """Common model and generation tests for the Bark semantic sub-model."""

    all_model_classes = (BarkSemanticModel,) if is_torch_available() else ()
    all_generative_model_classes = (BarkCausalModel,) if is_torch_available() else ()

    is_encoder_decoder = False
    fx_compatible = False
    test_missing_keys = False
    test_pruning = False
    # no model_parallel for now
    test_model_parallel = False

    test_resize_embeddings = True

    def setUp(self):
        self.model_tester = BarkSemanticModelTester(
            self, config_class=BarkSemanticConfig, model_class=BarkSemanticModel
        )
        self.config_tester = ConfigTester(self, config_class=BarkSemanticConfig, n_embd=37)

    def test_config(self):
        """Run the shared configuration checks."""
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        """A save/load round trip must not report any missing keys."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                _, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
            self.assertEqual(info["missing_keys"], [])

    def test_decoder_model_past_with_large_inputs(self):
        """Cached key/values must give the same logits as a full forward pass."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_inputs_embeds(self):
        """The forward pass must accept ``input_embeds`` in place of ``input_ids``."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))

            input_ids = inputs["input_ids"]
            del inputs["input_ids"]

            # embed the ids with the model's own input embedding layer
            wte = model.get_input_embeddings()
            inputs["input_embeds"] = wte(input_ids)

            with torch.no_grad():
                model(**inputs)[0]

    def test_generate_fp16(self):
        """``generate`` must run; on CUDA the model is cast to half precision first."""
        config, input_dict = self.model_tester.prepare_config_and_inputs()
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        model = self.all_generative_model_classes[0](config).eval().to(torch_device)
        if torch_device == "cuda":
            model.half()
        model.generate(input_ids, attention_mask=attention_mask)
        model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
|
|
|
|
|
@require_torch
class BarkCoarseModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    """Common model and generation tests for the Bark coarse sub-model."""

    # Mirrors BarkSemanticModelTest, but built on BarkCoarseModelTester / BarkCoarseConfig / BarkCoarseModel
    all_model_classes = (BarkCoarseModel,) if is_torch_available() else ()
    all_generative_model_classes = (BarkCausalModel,) if is_torch_available() else ()

    is_encoder_decoder = False
    fx_compatible = False
    test_missing_keys = False
    test_pruning = False
    # no model_parallel for now
    test_model_parallel = False

    test_resize_embeddings = True

    def setUp(self):
        self.model_tester = BarkCoarseModelTester(self, config_class=BarkCoarseConfig, model_class=BarkCoarseModel)
        self.config_tester = ConfigTester(self, config_class=BarkCoarseConfig, n_embd=37)

    def test_config(self):
        """Run the shared configuration checks."""
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        """A save/load round trip must not report any missing keys."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                _, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
            self.assertEqual(info["missing_keys"], [])

    def test_decoder_model_past_with_large_inputs(self):
        """Cached key/values must give the same logits as a full forward pass."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_inputs_embeds(self):
        """The forward pass must accept ``input_embeds`` in place of ``input_ids``."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))

            input_ids = inputs["input_ids"]
            del inputs["input_ids"]

            # embed the ids with the model's own input embedding layer
            wte = model.get_input_embeddings()
            inputs["input_embeds"] = wte(input_ids)

            with torch.no_grad():
                model(**inputs)[0]

    def test_generate_fp16(self):
        """``generate`` must run; on CUDA the model is cast to half precision first."""
        config, input_dict = self.model_tester.prepare_config_and_inputs()
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        model = self.all_generative_model_classes[0](config).eval().to(torch_device)
        if torch_device == "cuda":
            model.half()
        model.generate(input_ids, attention_mask=attention_mask)
        model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
|
|
|
|
|
@require_torch
class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
    """Test case for `BarkFineModel` built on `ModelTesterMixin`.

    The tests defined here index the embedding getters per codebook (`get_input_embeddings()[i]`),
    since the model keeps one embedding layer per codebook.
    """

    all_model_classes = (BarkFineModel,) if is_torch_available() else ()

    is_encoder_decoder = False
    fx_compatible = False
    test_missing_keys = False
    test_pruning = False
    # no model_parallel for now
    test_model_parallel = False

    # torchscript disabled for now because forward with an int
    test_torchscript = False

    test_resize_embeddings = True

    def setUp(self):
        """Create the model tester and the config tester used by the tests below."""
        self.model_tester = BarkFineModelTester(self, config_class=BarkFineConfig, model_class=BarkFineModel)
        self.config_tester = ConfigTester(self, config_class=BarkFineConfig, n_embd=37)

    def test_config(self):
        """Run the common configuration checks."""
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        """A model saved with `save_pretrained` must reload without any missing key."""
        config, _ = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as save_dir:
                model.save_pretrained(save_dir)
                _reloaded, loading_info = model_class.from_pretrained(save_dir, output_loading_info=True)
                self.assertEqual(loading_info["missing_keys"], [])

    def test_inputs_embeds(self):
        """Run a forward pass fed with `input_embeds` computed from one codebook instead of `input_ids`."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))

            # the ids are removed from the inputs and replaced by their embeddings
            input_ids = inputs.pop("input_ids")

            codebook_idx = inputs_dict["codebook_idx"]
            codebook_embedding = model.get_input_embeddings()[codebook_idx]
            inputs["input_embeds"] = codebook_embedding(input_ids[:, :, codebook_idx])

            with torch.no_grad():
                model(**inputs)[0]

    def test_generate_fp16(self):
        """Smoke-test `generate` with and without a temperature, in half precision on CUDA."""
        config, prepared_inputs = self.model_tester.prepare_config_and_inputs()
        input_ids = prepared_inputs["input_ids"]

        model = self.all_model_classes[0](config).eval().to(torch_device)
        if torch_device == "cuda":
            model.half()

        # toy generation_configs
        semantic_generation_config = BarkSemanticGenerationConfig(semantic_vocab_size=0)
        coarse_generation_config = BarkCoarseGenerationConfig(n_coarse_codebooks=config.n_codes_given)
        fine_generation_config = BarkFineGenerationConfig(
            max_fine_history_length=config.block_size // 2,
            max_fine_input_length=config.block_size,
            n_fine_codebooks=config.n_codes_total,
        )
        codebook_size = config.vocab_size - 1

        # everything except the temperature is shared by the two calls
        shared_generate_kwargs = {
            "history_prompt": None,
            "semantic_generation_config": semantic_generation_config,
            "coarse_generation_config": coarse_generation_config,
            "fine_generation_config": fine_generation_config,
            "codebook_size": codebook_size,
        }
        for temperature in (None, 0.7):
            model.generate(input_ids, temperature=temperature, **shared_generate_kwargs)

    def test_forward_signature(self):
        """The first two parameters of `forward` must be `codebook_idx` then `input_ids`."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        expected_arg_names = ["codebook_idx", "input_ids"]

        for model_class in self.all_model_classes:
            model = model_class(config)
            # signature.parameters is an ordered mapping => so arg_names order is deterministic
            arg_names = list(inspect.signature(model.forward).parameters)

            self.assertListEqual(arg_names[:2], expected_arg_names)

    def test_model_common_attributes(self):
        """Check the types held by the input / output embedding getters and that the setter accepts a ModuleList."""
        # one embedding layer per codebook
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings()[0], torch.nn.Embedding)

            replacement_embeddings = torch.nn.ModuleList(
                [torch.nn.Embedding(10, 10) for _ in range(config.n_codes_total)]
            )
            model.set_input_embeddings(replacement_embeddings)

            output_embeddings = model.get_output_embeddings()
            self.assertTrue(output_embeddings is None or isinstance(output_embeddings[0], torch.nn.Linear))

    def test_resize_tokens_embeddings(self):
        """Grow then shrink the vocabulary and check every embedding matrix of the returned list follows."""
        # resizing tokens_embeddings of a ModuleList
        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            return

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config)
            model.to(torch_device)

            if self.model_tester.is_training is False:
                model.eval()

            model_vocab_size = config.vocab_size
            # Retrieve the embeddings and clone them
            model_embed_list = model.resize_token_embeddings(model_vocab_size)
            cloned_embeddings_list = [embedding.weight.clone() for embedding in model_embed_list]

            # first a larger vocabulary (+10), then a smaller one (-15)
            for delta in (10, -15):
                resized_vocab_size = model_vocab_size + delta
                model_embed_list = model.resize_token_embeddings(resized_vocab_size)
                self.assertEqual(model.config.vocab_size, resized_vocab_size)

                # Check that it actually resizes the embeddings matrix for each codebook
                for model_embed, cloned_embeddings in zip(model_embed_list, cloned_embeddings_list):
                    self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + delta)

                if delta < 0:
                    # Input ids should be clamped to the maximum size of the vocabulary
                    inputs_dict["input_ids"].clamp_(max=resized_vocab_size - 1)

                # Check that the model can still do a forward pass successfully (every parameter should be resized)
                model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that adding and removing tokens has not modified the first part of the embedding matrix.
            # only check for the first embedding matrix
            models_equal = not any(
                p1.data.ne(p2.data).sum() > 0
                for p1, p2 in zip(cloned_embeddings_list[0], model_embed_list[0].weight)
            )

            self.assertTrue(models_equal)

    def test_resize_embeddings_untied(self):
        """With untied word embeddings, grow then shrink the vocabulary and check every output embedding follows."""
        # resizing tokens_embeddings of a ModuleList
        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            return

        original_config.tie_word_embeddings = False

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config).to(torch_device)

            # if no output embeddings -> leave test
            if model.get_output_embeddings() is None:
                continue

            model_vocab_size = config.vocab_size

            # first a larger vocabulary (+10), then a smaller one (-15)
            for delta in (10, -15):
                resized_vocab_size = model_vocab_size + delta
                model.resize_token_embeddings(resized_vocab_size)
                self.assertEqual(model.config.vocab_size, resized_vocab_size)

                # Check that it actually resizes the embeddings matrix
                for output_embeds in model.get_output_embeddings():
                    self.assertEqual(output_embeds.weight.shape[0], resized_vocab_size)

                    # Check bias if present
                    if output_embeds.bias is not None:
                        self.assertEqual(output_embeds.bias.shape[0], resized_vocab_size)

                if delta < 0:
                    # Input ids should be clamped to the maximum size of the vocabulary
                    inputs_dict["input_ids"].clamp_(max=resized_vocab_size - 1)

                # Check that the model can still do a forward pass successfully (every parameter should be resized)
                model(**self._prepare_for_class(inputs_dict, model_class))
|
|
|
|
|
|
@require_torch
|
|
class BarkModelIntegrationTests(unittest.TestCase):
|
|
@cached_property
def model(self):
    """Load the `ylacombe/bark-large` checkpoint and move it to `torch_device`."""
    pretrained = BarkModel.from_pretrained("ylacombe/bark-large")
    return pretrained.to(torch_device)
|
|
|
|
@cached_property
def processor(self):
    """Load the processor stored with the `ylacombe/bark-large` checkpoint."""
    checkpoint = "ylacombe/bark-large"
    return BarkProcessor.from_pretrained(checkpoint)
|
|
|
|
@cached_property
def inputs(self):
    """Process a fixed sentence with the `en_speaker_6` voice preset and move the result to `torch_device`."""
    processed = self.processor(
        "In the light of the moon, a little egg lay on a leaf",
        voice_preset="en_speaker_6",
    )
    return processed.to(torch_device)
|
|
|
|
@cached_property
def semantic_generation_config(self):
    """Build a `BarkSemanticGenerationConfig` from the `semantic_config` of the model's generation config."""
    semantic_kwargs = self.model.generation_config.semantic_config
    return BarkSemanticGenerationConfig(**semantic_kwargs)
|
|
|
|
@cached_property
def coarse_generation_config(self):
    """Build a `BarkCoarseGenerationConfig` from the `coarse_acoustics_config` of the model's generation config."""
    coarse_kwargs = self.model.generation_config.coarse_acoustics_config
    return BarkCoarseGenerationConfig(**coarse_kwargs)
|
|
|
|
@cached_property
def fine_generation_config(self):
    """Build a `BarkFineGenerationConfig` from the `fine_acoustics_config` of the model's generation config."""
    fine_kwargs = self.model.generation_config.fine_acoustics_config
    return BarkFineGenerationConfig(**fine_kwargs)
|
|
|
|
@slow
def test_generate_semantic(self):
    """Semantic generation without sampling must start with the expected ids."""
    processed_inputs = self.inputs

    # fmt: off
    # check first ids
    expected_output_ids = [7363, 321, 41, 1461, 6915, 952, 326, 41, 41, 927,]
    # fmt: on
    n_checked = len(expected_output_ids)

    # greedy decoding
    with torch.no_grad():
        semantic_ids = self.model.semantic.generate(
            **processed_inputs,
            do_sample=False,
            semantic_generation_config=self.semantic_generation_config,
        )

    # only the first sequence of the batch is compared
    self.assertListEqual(semantic_ids[0, :n_checked].tolist(), expected_output_ids)
|
|
|
|
@slow
def test_generate_coarse(self):
    """Chain semantic and coarse generation without sampling and compare the first coarse ids."""
    processed_inputs = self.inputs
    history_prompt = processed_inputs["history_prompt"]

    # fmt: off
    # check first ids
    expected_output_ids = [11018, 11391, 10651, 11418, 10857, 11620, 10642, 11366, 10312, 11528, 10531, 11516, 10474, 11051, 10524, 11051, ]
    # fmt: on
    n_checked = len(expected_output_ids)

    with torch.no_grad():
        semantic_ids = self.model.semantic.generate(
            **processed_inputs,
            do_sample=False,
            semantic_generation_config=self.semantic_generation_config,
        )

        # the semantic output is the input of the coarse sub-model
        coarse_ids = self.model.coarse_acoustics.generate(
            semantic_ids,
            history_prompt=history_prompt,
            do_sample=False,
            semantic_generation_config=self.semantic_generation_config,
            coarse_generation_config=self.coarse_generation_config,
            codebook_size=self.model.generation_config.codebook_size,
        )

    # only the first sequence of the batch is compared
    self.assertListEqual(coarse_ids[0, :n_checked].tolist(), expected_output_ids)
|
|
|
|
@slow
def test_generate_fine(self):
    """Chain semantic, coarse and fine generation and compare the first ids of every row of the fine output."""
    processed_inputs = self.inputs
    history_prompt = processed_inputs["history_prompt"]

    # fmt: off
    expected_output_ids = [
        [1018, 651, 857, 642, 312, 531, 474, 524, 524, 776,],
        [367, 394, 596, 342, 504, 492, 27, 27, 822, 822,],
        [961, 955, 221, 955, 955, 686, 939, 939, 479, 176,],
        [638, 365, 218, 944, 853, 363, 639, 22, 884, 456,],
        [302, 912, 524, 38, 174, 209, 879, 23, 910, 227,],
        [440, 673, 861, 666, 372, 558, 49, 172, 232, 342,],
        [244, 358, 123, 356, 586, 520, 499, 877, 542, 637,],
        [806, 685, 905, 848, 803, 810, 921, 208, 625, 203,],
    ]
    # fmt: on
    n_checked = len(expected_output_ids[0])

    with torch.no_grad():
        semantic_ids = self.model.semantic.generate(
            **processed_inputs,
            do_sample=False,
            semantic_generation_config=self.semantic_generation_config,
        )

        coarse_ids = self.model.coarse_acoustics.generate(
            semantic_ids,
            history_prompt=history_prompt,
            do_sample=False,
            semantic_generation_config=self.semantic_generation_config,
            coarse_generation_config=self.coarse_generation_config,
            codebook_size=self.model.generation_config.codebook_size,
        )

        # greedy decoding
        fine_ids = self.model.fine_acoustics.generate(
            coarse_ids,
            history_prompt=history_prompt,
            temperature=None,
            semantic_generation_config=self.semantic_generation_config,
            coarse_generation_config=self.coarse_generation_config,
            fine_generation_config=self.fine_generation_config,
            codebook_size=self.model.generation_config.codebook_size,
        )

    # first batch element, every row, first `n_checked` positions
    self.assertListEqual(fine_ids[0, :, :n_checked].tolist(), expected_output_ids)
|
|
|
|
@slow
def test_generate_end_to_end(self):
    """Smoke-test `BarkModel.generate` with the full inputs, then without the `history_prompt` entry."""
    processed_inputs = self.inputs

    with torch.no_grad():
        self.model.generate(**processed_inputs)

        # same inputs with the "history_prompt" key filtered out
        inputs_without_history = {
            name: value for name, value in processed_inputs.items() if name != "history_prompt"
        }
        self.model.generate(**inputs_without_history)
|
|
|
|
@slow
def test_generate_end_to_end_with_args(self):
    """Smoke-test `BarkModel.generate` with sampling enabled plus either `penalty_alpha` or `num_beams`."""
    processed_inputs = self.inputs
    # options common to both calls
    sampling_kwargs = {"do_sample": True, "temperature": 0.6}

    with torch.no_grad():
        self.model.generate(**processed_inputs, **sampling_kwargs, penalty_alpha=0.6)
        self.model.generate(**processed_inputs, **sampling_kwargs, num_beams=4)
|
|
|
|
@slow
def test_generate_end_to_end_with_sub_models_args(self):
    """Smoke-test `BarkModel.generate` with `semantic_` / `coarse_` / `fine_` prefixed keyword arguments."""
    processed_inputs = self.inputs

    # one entry per `generate` call, executed in this order
    generate_kwargs_per_call = (
        {"do_sample": False, "coarse_do_sample": True, "coarse_temperature": 0.7},
        {"do_sample": False, "coarse_do_sample": True, "coarse_temperature": 0.7, "fine_temperature": 0.3},
        {
            "do_sample": True,
            "temperature": 0.6,
            "penalty_alpha": 0.6,
            "semantic_temperature": 0.9,
            "coarse_temperature": 0.2,
            "fine_temperature": 0.1,
        },
    )

    with torch.no_grad():
        for generate_kwargs in generate_kwargs_per_call:
            self.model.generate(**processed_inputs, **generate_kwargs)
|