[VITS] Fix init test (#25945)

* [VITS] Fix init test

* add `is_flaky` decorator to `test_initialization`

* style

* set `max_attempts=3` on the `is_flaky` decorator

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>

* style

---------

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
Sanchit Gandhi 2023-09-04 17:09:26 +01:00 committed by GitHub
parent 7cd01d4e38
commit d750eff627
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -24,6 +24,7 @@ import numpy as np
from transformers import PretrainedConfig, VitsConfig from transformers import PretrainedConfig, VitsConfig
from transformers.testing_utils import ( from transformers.testing_utils import (
is_flaky,
is_torch_available, is_torch_available,
require_torch, require_torch,
slow, slow,
@ -80,6 +81,10 @@ class VitsModelTester:
duration_predictor_filter_channels=16, duration_predictor_filter_channels=16,
prior_encoder_num_flows=2, prior_encoder_num_flows=2,
upsample_initial_channel=16, upsample_initial_channel=16,
upsample_rates=[8, 2],
upsample_kernel_sizes=[16, 4],
resblock_kernel_sizes=[3, 7],
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5]],
): ):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
@ -96,6 +101,10 @@ class VitsModelTester:
self.duration_predictor_filter_channels = duration_predictor_filter_channels self.duration_predictor_filter_channels = duration_predictor_filter_channels
self.prior_encoder_num_flows = prior_encoder_num_flows self.prior_encoder_num_flows = prior_encoder_num_flows
self.upsample_initial_channel = upsample_initial_channel self.upsample_initial_channel = upsample_initial_channel
self.upsample_rates = upsample_rates
self.upsample_kernel_sizes = upsample_kernel_sizes
self.resblock_kernel_sizes = resblock_kernel_sizes
self.resblock_dilation_sizes = resblock_dilation_sizes
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(2) input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(2)
@ -126,6 +135,10 @@ class VitsModelTester:
duration_predictor_filter_channels=self.duration_predictor_filter_channels, duration_predictor_filter_channels=self.duration_predictor_filter_channels,
posterior_encoder_num_wavenet_layers=self.num_hidden_layers, posterior_encoder_num_wavenet_layers=self.num_hidden_layers,
upsample_initial_channel=self.upsample_initial_channel, upsample_initial_channel=self.upsample_initial_channel,
upsample_rates=self.upsample_rates,
upsample_kernel_sizes=self.upsample_kernel_sizes,
resblock_kernel_sizes=self.resblock_kernel_sizes,
resblock_dilation_sizes=self.resblock_dilation_sizes,
) )
def create_and_check_model_forward(self, config, inputs_dict): def create_and_check_model_forward(self, config, inputs_dict):
@ -135,7 +148,7 @@ class VitsModelTester:
attention_mask = inputs_dict["attention_mask"] attention_mask = inputs_dict["attention_mask"]
result = model(input_ids, attention_mask=attention_mask) result = model(input_ids, attention_mask=attention_mask)
self.parent.assertEqual(result.waveform.shape, (self.batch_size, 11008)) self.parent.assertEqual((self.batch_size, 624), result.waveform.shape)
@require_torch @require_torch
@ -168,15 +181,13 @@ class VitsModelTest(ModelTesterMixin, unittest.TestCase):
def test_determinism(self): def test_determinism(self):
pass pass
# TODO: Fix me (ydshieh) @is_flaky(
@unittest.skip("currently failing") max_attempts=3,
description="Weight initialisation for the VITS conv layers sometimes exceeds the kaiming normal range",
)
def test_initialization(self): def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
uniform_init_parms = [ uniform_init_parms = [
"emb_rel_k", "emb_rel_k",
"emb_rel_v", "emb_rel_v",
@ -192,6 +203,11 @@ class VitsModelTest(ModelTesterMixin, unittest.TestCase):
"upsampler", "upsampler",
"resblocks", "resblocks",
] ]
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad: if param.requires_grad:
if any(x in name for x in uniform_init_parms): if any(x in name for x in uniform_init_parms):
self.assertTrue( self.assertTrue(