mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
[VITS] Fix init test (#25945)
* [VITS] Fix init test * add flaky decorator * style * max attempts Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> * style --------- Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
parent
7cd01d4e38
commit
d750eff627
@ -24,6 +24,7 @@ import numpy as np
|
|||||||
|
|
||||||
from transformers import PretrainedConfig, VitsConfig
|
from transformers import PretrainedConfig, VitsConfig
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
is_flaky,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
require_torch,
|
require_torch,
|
||||||
slow,
|
slow,
|
||||||
@ -80,6 +81,10 @@ class VitsModelTester:
|
|||||||
duration_predictor_filter_channels=16,
|
duration_predictor_filter_channels=16,
|
||||||
prior_encoder_num_flows=2,
|
prior_encoder_num_flows=2,
|
||||||
upsample_initial_channel=16,
|
upsample_initial_channel=16,
|
||||||
|
upsample_rates=[8, 2],
|
||||||
|
upsample_kernel_sizes=[16, 4],
|
||||||
|
resblock_kernel_sizes=[3, 7],
|
||||||
|
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5]],
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@ -96,6 +101,10 @@ class VitsModelTester:
|
|||||||
self.duration_predictor_filter_channels = duration_predictor_filter_channels
|
self.duration_predictor_filter_channels = duration_predictor_filter_channels
|
||||||
self.prior_encoder_num_flows = prior_encoder_num_flows
|
self.prior_encoder_num_flows = prior_encoder_num_flows
|
||||||
self.upsample_initial_channel = upsample_initial_channel
|
self.upsample_initial_channel = upsample_initial_channel
|
||||||
|
self.upsample_rates = upsample_rates
|
||||||
|
self.upsample_kernel_sizes = upsample_kernel_sizes
|
||||||
|
self.resblock_kernel_sizes = resblock_kernel_sizes
|
||||||
|
self.resblock_dilation_sizes = resblock_dilation_sizes
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(2)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(2)
|
||||||
@ -126,6 +135,10 @@ class VitsModelTester:
|
|||||||
duration_predictor_filter_channels=self.duration_predictor_filter_channels,
|
duration_predictor_filter_channels=self.duration_predictor_filter_channels,
|
||||||
posterior_encoder_num_wavenet_layers=self.num_hidden_layers,
|
posterior_encoder_num_wavenet_layers=self.num_hidden_layers,
|
||||||
upsample_initial_channel=self.upsample_initial_channel,
|
upsample_initial_channel=self.upsample_initial_channel,
|
||||||
|
upsample_rates=self.upsample_rates,
|
||||||
|
upsample_kernel_sizes=self.upsample_kernel_sizes,
|
||||||
|
resblock_kernel_sizes=self.resblock_kernel_sizes,
|
||||||
|
resblock_dilation_sizes=self.resblock_dilation_sizes,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_model_forward(self, config, inputs_dict):
|
def create_and_check_model_forward(self, config, inputs_dict):
|
||||||
@ -135,7 +148,7 @@ class VitsModelTester:
|
|||||||
attention_mask = inputs_dict["attention_mask"]
|
attention_mask = inputs_dict["attention_mask"]
|
||||||
|
|
||||||
result = model(input_ids, attention_mask=attention_mask)
|
result = model(input_ids, attention_mask=attention_mask)
|
||||||
self.parent.assertEqual(result.waveform.shape, (self.batch_size, 11008))
|
self.parent.assertEqual((self.batch_size, 624), result.waveform.shape)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@ -168,15 +181,13 @@ class VitsModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def test_determinism(self):
|
def test_determinism(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# TODO: Fix me (ydshieh)
|
@is_flaky(
|
||||||
@unittest.skip("currently failing")
|
max_attempts=3,
|
||||||
|
description="Weight initialisation for the VITS conv layers sometimes exceeds the kaiming normal range",
|
||||||
|
)
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
configs_no_init = _config_zero_init(config)
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
model = model_class(config=configs_no_init)
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
uniform_init_parms = [
|
uniform_init_parms = [
|
||||||
"emb_rel_k",
|
"emb_rel_k",
|
||||||
"emb_rel_v",
|
"emb_rel_v",
|
||||||
@ -192,6 +203,11 @@ class VitsModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
"upsampler",
|
"upsampler",
|
||||||
"resblocks",
|
"resblocks",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config)
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
for name, param in model.named_parameters():
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
if any(x in name for x in uniform_init_parms):
|
if any(x in name for x in uniform_init_parms):
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
|
Loading…
Reference in New Issue
Block a user