mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
MusicGen Update (#27084)
* [MusicGen] Add stereo model * safe serialization * Update src/transformers/models/musicgen/modeling_musicgen.py * split over 2 lines * fix slow tests on cuda
This commit is contained in:
parent
5ef650b0ae
commit
f16ff0f07e
@ -57,6 +57,11 @@ Generation is limited by the sinusoidal positional embeddings to 30 second input
|
||||
than 30 seconds of audio (1503 tokens), and input audio passed by Audio-Prompted Generation contributes to this limit so,
|
||||
given an input of 20 seconds of audio, MusicGen cannot generate more than 10 seconds of additional audio.
|
||||
|
||||
Transformers supports both mono (1-channel) and stereo (2-channel) variants of MusicGen. The mono channel versions
|
||||
generate a single set of codebooks. The stereo versions generate 2 sets of codebooks, 1 for each channel (left/right),
|
||||
and each set of codebooks is decoded independently through the audio compression model. The audio streams for each
|
||||
channel are combined to give the final stereo output.
|
||||
|
||||
### Unconditional Generation
|
||||
|
||||
The inputs for unconditional (or 'null') generation can be obtained through the method
|
||||
|
@ -75,6 +75,9 @@ class MusicgenDecoderConfig(PretrainedConfig):
|
||||
The number of parallel codebooks forwarded to the model.
|
||||
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
|
||||
Whether input and output word embeddings should be tied.
|
||||
audio_channels (`int`, *optional*, defaults to 1
|
||||
Number of channels in the audio data. Either 1 for mono or 2 for stereo. Stereo models generate a separate
|
||||
audio stream for the left/right output channels. Mono models generate a single audio stream output.
|
||||
"""
|
||||
model_type = "musicgen_decoder"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
@ -96,6 +99,7 @@ class MusicgenDecoderConfig(PretrainedConfig):
|
||||
initializer_factor=0.02,
|
||||
scale_embedding=False,
|
||||
num_codebooks=4,
|
||||
audio_channels=1,
|
||||
pad_token_id=2048,
|
||||
bos_token_id=2048,
|
||||
eos_token_id=None,
|
||||
@ -117,6 +121,11 @@ class MusicgenDecoderConfig(PretrainedConfig):
|
||||
self.use_cache = use_cache
|
||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
self.num_codebooks = num_codebooks
|
||||
|
||||
if audio_channels not in [1, 2]:
|
||||
raise ValueError(f"Expected 1 (mono) or 2 (stereo) audio channels, got {audio_channels} channels.")
|
||||
self.audio_channels = audio_channels
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
|
@ -88,32 +88,48 @@ def rename_state_dict(state_dict: OrderedDict, hidden_size: int) -> Tuple[Dict,
|
||||
|
||||
|
||||
def decoder_config_from_checkpoint(checkpoint: str) -> MusicgenDecoderConfig:
|
||||
if checkpoint == "small":
|
||||
if checkpoint == "small" or checkpoint == "facebook/musicgen-stereo-small":
|
||||
# default config values
|
||||
hidden_size = 1024
|
||||
num_hidden_layers = 24
|
||||
num_attention_heads = 16
|
||||
elif checkpoint == "medium":
|
||||
elif checkpoint == "medium" or checkpoint == "facebook/musicgen-stereo-medium":
|
||||
hidden_size = 1536
|
||||
num_hidden_layers = 48
|
||||
num_attention_heads = 24
|
||||
elif checkpoint == "large":
|
||||
elif checkpoint == "large" or checkpoint == "facebook/musicgen-stereo-large":
|
||||
hidden_size = 2048
|
||||
num_hidden_layers = 48
|
||||
num_attention_heads = 32
|
||||
else:
|
||||
raise ValueError(f"Checkpoint should be one of `['small', 'medium', 'large']`, got {checkpoint}.")
|
||||
raise ValueError(
|
||||
"Checkpoint should be one of `['small', 'medium', 'large']` for the mono checkpoints, "
|
||||
"or `['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` "
|
||||
f"for the stereo checkpoints, got {checkpoint}."
|
||||
)
|
||||
|
||||
if "stereo" in checkpoint:
|
||||
audio_channels = 2
|
||||
num_codebooks = 8
|
||||
else:
|
||||
audio_channels = 1
|
||||
num_codebooks = 4
|
||||
|
||||
config = MusicgenDecoderConfig(
|
||||
hidden_size=hidden_size,
|
||||
ffn_dim=hidden_size * 4,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_codebooks=num_codebooks,
|
||||
audio_channels=audio_channels,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu"):
|
||||
def convert_musicgen_checkpoint(
|
||||
checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu", safe_serialization=False
|
||||
):
|
||||
fairseq_model = MusicGen.get_pretrained(checkpoint, device=device)
|
||||
decoder_config = decoder_config_from_checkpoint(checkpoint)
|
||||
|
||||
@ -146,18 +162,20 @@ def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=No
|
||||
model.enc_to_dec_proj.load_state_dict(enc_dec_proj_state_dict)
|
||||
|
||||
# check we can do a forward pass
|
||||
input_ids = torch.arange(0, 8, dtype=torch.long).reshape(2, -1)
|
||||
decoder_input_ids = input_ids.reshape(2 * 4, -1)
|
||||
input_ids = torch.arange(0, 2 * decoder_config.num_codebooks, dtype=torch.long).reshape(2, -1)
|
||||
decoder_input_ids = input_ids.reshape(2 * decoder_config.num_codebooks, -1)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
|
||||
|
||||
if logits.shape != (8, 1, 2048):
|
||||
if logits.shape != (2 * decoder_config.num_codebooks, 1, 2048):
|
||||
raise ValueError("Incorrect shape for logits")
|
||||
|
||||
# now construct the processor
|
||||
tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/encodec_32khz", padding_side="left")
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
"facebook/encodec_32khz", padding_side="left", feature_size=decoder_config.audio_channels
|
||||
)
|
||||
|
||||
processor = MusicgenProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
|
||||
|
||||
@ -173,12 +191,12 @@ def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=No
|
||||
if pytorch_dump_folder is not None:
|
||||
Path(pytorch_dump_folder).mkdir(exist_ok=True)
|
||||
logger.info(f"Saving model {checkpoint} to {pytorch_dump_folder}")
|
||||
model.save_pretrained(pytorch_dump_folder)
|
||||
model.save_pretrained(pytorch_dump_folder, safe_serialization=safe_serialization)
|
||||
processor.save_pretrained(pytorch_dump_folder)
|
||||
|
||||
if repo_id:
|
||||
logger.info(f"Pushing model {checkpoint} to {repo_id}")
|
||||
model.push_to_hub(repo_id)
|
||||
model.push_to_hub(repo_id, safe_serialization=safe_serialization)
|
||||
processor.push_to_hub(repo_id)
|
||||
|
||||
|
||||
@ -189,7 +207,10 @@ if __name__ == "__main__":
|
||||
"--checkpoint",
|
||||
default="small",
|
||||
type=str,
|
||||
help="Checkpoint size of the MusicGen model you'd like to convert. Can be one of: `['small', 'medium', 'large']`.",
|
||||
help="Checkpoint size of the MusicGen model you'd like to convert. Can be one of: "
|
||||
"`['small', 'medium', 'large']` for the mono checkpoints, or "
|
||||
"`['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` "
|
||||
"for the stereo checkpoints.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder",
|
||||
@ -204,6 +225,11 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--safe_serialization",
|
||||
action="store_true",
|
||||
help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_musicgen_checkpoint(args.checkpoint, args.pytorch_dump_folder, args.push_to_hub)
|
||||
|
@ -1077,21 +1077,33 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
|
||||
torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
|
||||
)
|
||||
|
||||
channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks
|
||||
# we only apply the mask if we have a large enough seq len - otherwise we return as is
|
||||
if max_length < 2 * num_codebooks - 1:
|
||||
if max_length < 2 * channel_codebooks - 1:
|
||||
return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
|
||||
|
||||
# fill the shifted ids with the prompt entries, offset by the codebook idx
|
||||
for codebook in range(num_codebooks):
|
||||
input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
|
||||
for codebook in range(channel_codebooks):
|
||||
if self.config.audio_channels == 1:
|
||||
# mono channel - loop over the codebooks one-by-one
|
||||
input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
|
||||
else:
|
||||
# left/right channels are interleaved in the generated codebooks, so handle one then the other
|
||||
input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook]
|
||||
input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1]
|
||||
|
||||
# construct a pattern mask that indicates the positions of padding tokens for each codebook
|
||||
# first fill the upper triangular part (the EOS padding)
|
||||
delay_pattern = torch.triu(
|
||||
torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
|
||||
torch.ones((channel_codebooks, max_length), dtype=torch.bool), diagonal=max_length - channel_codebooks + 1
|
||||
)
|
||||
# then fill the lower triangular part (the BOS padding)
|
||||
delay_pattern = delay_pattern + torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
|
||||
delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.bool))
|
||||
|
||||
if self.config.audio_channels == 2:
|
||||
# for left/right channel we need to duplicate every row of the pattern mask in an interleaved fashion
|
||||
delay_pattern = delay_pattern.repeat_interleave(2, dim=0)
|
||||
|
||||
mask = ~delay_pattern.to(input_ids.device)
|
||||
input_ids = mask * input_ids_shifted + ~mask * pad_token_id
|
||||
|
||||
@ -1856,6 +1868,11 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
||||
f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
|
||||
"disabled by setting `chunk_length=None` in the audio encoder."
|
||||
)
|
||||
|
||||
if self.config.audio_channels == 2 and audio_codes.shape[2] == self.decoder.num_codebooks // 2:
|
||||
# mono input through encodec that we convert to stereo
|
||||
audio_codes = audio_codes.repeat_interleave(2, dim=2)
|
||||
|
||||
decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)
|
||||
|
||||
# Decode
|
||||
@ -2074,12 +2091,42 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
||||
# 3. make sure that encoder returns `ModelOutput`
|
||||
model_input_name = model_input_name if model_input_name is not None else self.audio_encoder.main_input_name
|
||||
encoder_kwargs["return_dict"] = True
|
||||
encoder_kwargs[model_input_name] = input_values
|
||||
|
||||
audio_encoder_outputs = encoder.encode(**encoder_kwargs)
|
||||
if self.decoder.config.audio_channels == 1:
|
||||
encoder_kwargs[model_input_name] = input_values
|
||||
audio_encoder_outputs = encoder.encode(**encoder_kwargs)
|
||||
audio_codes = audio_encoder_outputs.audio_codes
|
||||
audio_scales = audio_encoder_outputs.audio_scales
|
||||
|
||||
audio_codes = audio_encoder_outputs.audio_codes
|
||||
frames, bsz, codebooks, seq_len = audio_codes.shape
|
||||
frames, bsz, codebooks, seq_len = audio_codes.shape
|
||||
|
||||
else:
|
||||
if input_values.shape[1] != 2:
|
||||
raise ValueError(
|
||||
f"Expected stereo audio (2-channels) but example has {input_values.shape[1]} channel."
|
||||
)
|
||||
|
||||
encoder_kwargs[model_input_name] = input_values[:, :1, :]
|
||||
audio_encoder_outputs_left = encoder.encode(**encoder_kwargs)
|
||||
audio_codes_left = audio_encoder_outputs_left.audio_codes
|
||||
audio_scales_left = audio_encoder_outputs_left.audio_scales
|
||||
|
||||
encoder_kwargs[model_input_name] = input_values[:, 1:, :]
|
||||
audio_encoder_outputs_right = encoder.encode(**encoder_kwargs)
|
||||
audio_codes_right = audio_encoder_outputs_right.audio_codes
|
||||
audio_scales_right = audio_encoder_outputs_right.audio_scales
|
||||
|
||||
frames, bsz, codebooks, seq_len = audio_codes_left.shape
|
||||
# copy alternating left/right channel codes into stereo codebook
|
||||
audio_codes = audio_codes_left.new_ones((frames, bsz, 2 * codebooks, seq_len))
|
||||
|
||||
audio_codes[:, :, ::2, :] = audio_codes_left
|
||||
audio_codes[:, :, 1::2, :] = audio_codes_right
|
||||
|
||||
if audio_scales_left != [None] or audio_scales_right != [None]:
|
||||
audio_scales = torch.stack([audio_scales_left, audio_scales_right], dim=1)
|
||||
else:
|
||||
audio_scales = [None] * bsz
|
||||
|
||||
if frames != 1:
|
||||
raise ValueError(
|
||||
@ -2090,7 +2137,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
||||
decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)
|
||||
|
||||
model_kwargs["decoder_input_ids"] = decoder_input_ids
|
||||
model_kwargs["audio_scales"] = audio_encoder_outputs.audio_scales
|
||||
model_kwargs["audio_scales"] = audio_scales
|
||||
return model_kwargs
|
||||
|
||||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||
@ -2433,16 +2480,25 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
||||
if audio_scales is None:
|
||||
audio_scales = [None] * batch_size
|
||||
|
||||
output_values = self.audio_encoder.decode(
|
||||
output_ids,
|
||||
audio_scales=audio_scales,
|
||||
)
|
||||
if self.decoder.config.audio_channels == 1:
|
||||
output_values = self.audio_encoder.decode(
|
||||
output_ids,
|
||||
audio_scales=audio_scales,
|
||||
).audio_values
|
||||
else:
|
||||
codec_outputs_left = self.audio_encoder.decode(output_ids[:, :, ::2, :], audio_scales=audio_scales)
|
||||
output_values_left = codec_outputs_left.audio_values
|
||||
|
||||
codec_outputs_right = self.audio_encoder.decode(output_ids[:, :, 1::2, :], audio_scales=audio_scales)
|
||||
output_values_right = codec_outputs_right.audio_values
|
||||
|
||||
output_values = torch.cat([output_values_left, output_values_right], dim=1)
|
||||
|
||||
if generation_config.return_dict_in_generate:
|
||||
outputs.sequences = output_values.audio_values
|
||||
outputs.sequences = output_values
|
||||
return outputs
|
||||
else:
|
||||
return output_values.audio_values
|
||||
return output_values
|
||||
|
||||
def get_unconditional_inputs(self, num_samples=1):
|
||||
"""
|
||||
|
@ -379,6 +379,27 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
|
||||
|
||||
def test_greedy_generate_stereo_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config.audio_channels = 2
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_greedy, output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
max_length=max_length,
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
|
||||
def prepare_musicgen_inputs_dict(
|
||||
config,
|
||||
@ -1102,6 +1123,29 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
input_dict["input_ids"], attention_mask=input_dict["attention_mask"], do_sample=True, max_new_tokens=10
|
||||
)
|
||||
|
||||
def test_greedy_generate_stereo_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
|
||||
config.audio_channels = 2
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_greedy, output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
max_length=max_length,
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
|
||||
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
|
||||
"""Produces a series of 'bip bip' sounds at a given frequency."""
|
||||
@ -1357,3 +1401,79 @@ class MusicgenIntegrationTests(unittest.TestCase):
|
||||
output_values.shape == (2, 1, 36480)
|
||||
) # input values take shape 32000 and we generate from there
|
||||
self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, atol=1e-4))
|
||||
|
||||
|
||||
@require_torch
|
||||
class MusicgenStereoIntegrationTests(unittest.TestCase):
|
||||
@cached_property
|
||||
def model(self):
|
||||
return MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-stereo-small").to(torch_device)
|
||||
|
||||
@cached_property
|
||||
def processor(self):
|
||||
return MusicgenProcessor.from_pretrained("facebook/musicgen-stereo-small")
|
||||
|
||||
@slow
|
||||
def test_generate_unconditional_greedy(self):
|
||||
model = self.model
|
||||
|
||||
# only generate 1 sample with greedy - since it's deterministic all elements of the batch will be the same
|
||||
unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
|
||||
unconditional_inputs = place_dict_on_device(unconditional_inputs, device=torch_device)
|
||||
|
||||
output_values = model.generate(**unconditional_inputs, do_sample=False, max_new_tokens=12)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_VALUES_LEFT = torch.tensor(
|
||||
[
|
||||
0.0017, 0.0004, 0.0004, 0.0005, 0.0002, 0.0002, -0.0002, -0.0013,
|
||||
-0.0010, -0.0015, -0.0018, -0.0032, -0.0060, -0.0082, -0.0096, -0.0099,
|
||||
]
|
||||
)
|
||||
EXPECTED_VALUES_RIGHT = torch.tensor(
|
||||
[
|
||||
0.0038, 0.0028, 0.0031, 0.0032, 0.0031, 0.0032, 0.0030, 0.0019,
|
||||
0.0021, 0.0015, 0.0009, -0.0008, -0.0040, -0.0067, -0.0087, -0.0096,
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# (bsz, channels, seq_len)
|
||||
self.assertTrue(output_values.shape == (1, 2, 5760))
|
||||
self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES_LEFT, atol=1e-4))
|
||||
self.assertTrue(torch.allclose(output_values[0, 1, :16].cpu(), EXPECTED_VALUES_RIGHT, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_generate_text_audio_prompt(self):
|
||||
model = self.model
|
||||
processor = self.processor
|
||||
|
||||
# create stereo inputs
|
||||
audio = [get_bip_bip(duration=0.5)[None, :].repeat(2, 0), get_bip_bip(duration=1.0)[None, :].repeat(2, 0)]
|
||||
text = ["80s music", "Club techno"]
|
||||
|
||||
inputs = processor(audio=audio, text=text, padding=True, return_tensors="pt")
|
||||
inputs = place_dict_on_device(inputs, device=torch_device)
|
||||
|
||||
output_values = model.generate(**inputs, do_sample=False, guidance_scale=3.0, max_new_tokens=12)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_VALUES_LEFT = torch.tensor(
|
||||
[
|
||||
0.2535, 0.2008, 0.1471, 0.0896, 0.0306, -0.0200, -0.0501, -0.0728,
|
||||
-0.0832, -0.0856, -0.0867, -0.0884, -0.0864, -0.0866, -0.0744, -0.0430,
|
||||
]
|
||||
)
|
||||
EXPECTED_VALUES_RIGHT = torch.tensor(
|
||||
[
|
||||
0.1695, 0.1213, 0.0732, 0.0239, -0.0264, -0.0705, -0.0935, -0.1103,
|
||||
-0.1163, -0.1139, -0.1104, -0.1082, -0.1027, -0.1004, -0.0900, -0.0614,
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# (bsz, channels, seq_len)
|
||||
self.assertTrue(output_values.shape == (2, 2, 37760))
|
||||
# input values take shape 32000 and we generate from there - we check the last (generated) values
|
||||
self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, atol=1e-4))
|
||||
self.assertTrue(torch.allclose(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, atol=1e-4))
|
||||
|
Loading…
Reference in New Issue
Block a user