Add speecht5 batch generation and fix wrong attention mask when padding (#25943)

* fix speecht5 wrong attention mask when padding

* enable batch generation and add parameter attention_mask

* fix doc

* fix format

* batch postnet inputs, return batched lengths, and keep consistency with the old API

* fix format

* fix format

* fix the format

* fix doc-builder error

* add test, cross attention and docstring

* optimize code based on reviews

* docbuild

* refine

* do not skip slow test

* add consistent dropout for batching

* loosen atol

* add another test regarding the consistency of the vocoder

* fix format

* refactor

* add return_concrete_lengths as a parameter for consistency with/without batching

* fix review issues

* fix cross_attention issue
Sihan Chen, 2023-11-14 17:54:09 +08:00, committed by GitHub
parent ee4fb326c7
commit 4309abedbc
2 changed files with 257 additions and 54 deletions
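In short: `generate` and `generate_speech` now accept a padded batch together with its `attention_mask`, and `return_output_lengths=True` additionally returns each item's true length. A minimal usage sketch of the new surface (checkpoint names and the zero speaker embedding are taken from the tests below; real speaker embeddings would normally come from x-vectors):

import torch
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

texts = ["the first sentence", "a slightly longer second sentence"]
inputs = processor(text=texts, padding=True, return_tensors="pt")
speaker_embeddings = torch.zeros((1, 512))  # placeholder, as in the tests

# Batched generation: the attention mask marks the padded token positions,
# and return_output_lengths=True also returns each waveform's true length.
waveforms, waveform_lengths = model.generate_speech(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    speaker_embeddings=speaker_embeddings,
    vocoder=vocoder,
    return_output_lengths=True,
)
first_speech = waveforms[0, : waveform_lengths[0]]  # strip padding from item 0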

src/transformers/models/speecht5/modeling_speecht5.py

@@ -674,6 +674,11 @@ class SpeechT5SpeechDecoderPrenet(nn.Module):
self.speaker_embeds_layer = nn.Linear(config.speaker_embedding_dim + config.hidden_size, config.hidden_size)
def _consistent_dropout(self, inputs_embeds, p):
mask = torch.bernoulli(inputs_embeds[0], p=p)
all_masks = mask.unsqueeze(0).repeat(inputs_embeds.size(0), 1, 1)
return torch.where(all_masks == 1, inputs_embeds, 0) * 1 / (1 - p)
def forward(
self,
input_values: torch.Tensor,
@@ -684,9 +689,7 @@ class SpeechT5SpeechDecoderPrenet(nn.Module):
inputs_embeds = input_values
for layer in self.layers:
inputs_embeds = nn.functional.relu(layer(inputs_embeds))
inputs_embeds = nn.functional.dropout(
inputs_embeds, self.config.speech_decoder_prenet_dropout, training=True
)
inputs_embeds = self._consistent_dropout(inputs_embeds, self.config.speech_decoder_prenet_dropout)
inputs_embeds = self.final_layer(inputs_embeds)
inputs_embeds = self.encode_positions(inputs_embeds)
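The new `_consistent_dropout` draws a single dropout mask and repeats it across the batch, so a padded batch and a single sequence see identical dropout patterns during generation. A behavior-level sketch of the idea, using a plain `torch.rand` keep-mask rather than the `torch.bernoulli` overload in the patch:

import torch

def consistent_dropout(x: torch.Tensor, p: float) -> torch.Tensor:
    # Draw ONE keep/drop mask from the first batch item, then repeat it over
    # the batch so every item is dropped at the same positions; this is what
    # lets batched generation reproduce unbatched generation.
    keep = (torch.rand_like(x[0]) >= p).to(x.dtype)   # 1 = keep, with prob (1 - p)
    keep = keep.unsqueeze(0).repeat(x.size(0), 1, 1)
    return x * keep / (1 - p)                         # inverted-dropout rescale

x = torch.randn(1, 5, 8).repeat(3, 1, 1)  # a "batch" of identical sequences
y = consistent_dropout(x, p=0.5)
assert torch.allclose(y[0], y[1]) and torch.allclose(y[0], y[2])  # same pattern everywhere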
@@ -695,6 +698,7 @@ class SpeechT5SpeechDecoderPrenet(nn.Module):
speaker_embeddings = nn.functional.normalize(speaker_embeddings)
speaker_embeddings = speaker_embeddings.unsqueeze(1)
speaker_embeddings = speaker_embeddings.expand(-1, inputs_embeds.size(1), -1)
speaker_embeddings = speaker_embeddings.repeat(inputs_embeds.size(0), 1, 1)
inputs_embeds = torch.cat([inputs_embeds, speaker_embeddings], dim=-1)
inputs_embeds = nn.functional.relu(self.speaker_embeds_layer(inputs_embeds))
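The added `repeat` tiles one speaker embedding across the batch, so a single `(1, 512)` vector can drive every batch item, as in the tests below. Shape-wise:

import torch

bsz, seq_len, dim = 3, 7, 512
spk = torch.nn.functional.normalize(torch.randn(1, dim))  # one speaker vector
spk = spk.unsqueeze(1)              # (1, 1, dim)
spk = spk.expand(-1, seq_len, -1)   # (1, seq_len, dim)
spk = spk.repeat(bsz, 1, 1)         # (bsz, seq_len, dim): same speaker for every item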
@@ -2461,11 +2465,13 @@ def _generate_speech(
model: SpeechT5PreTrainedModel,
input_values: torch.FloatTensor,
speaker_embeddings: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
threshold: float = 0.5,
minlenratio: float = 0.0,
maxlenratio: float = 20.0,
vocoder: Optional[nn.Module] = None,
output_cross_attentions: bool = False,
return_output_lengths: bool = False,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
if speaker_embeddings is None:
raise ValueError(
@@ -2475,7 +2481,12 @@ def _generate_speech(
"""
)
encoder_attention_mask = torch.ones_like(input_values)
if attention_mask is None:
encoder_attention_mask = 1 - (input_values == model.config.pad_token_id).int()
else:
encoder_attention_mask = attention_mask
bsz = input_values.size(0)
encoder_out = model.speecht5.encoder(
input_values=input_values,
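When no `attention_mask` is passed, the mask is now derived from the pad token instead of being all ones. The equivalent standalone computation (pad id chosen for illustration; the model reads `model.config.pad_token_id`):

import torch

pad_token_id = 1  # illustrative value
input_values = torch.tensor([[17, 42, 7, pad_token_id, pad_token_id]])
encoder_attention_mask = 1 - (input_values == pad_token_id).int()
print(encoder_attention_mask)  # tensor([[1, 1, 1, 0, 0]]) -> padded positions are masked out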
@@ -2495,19 +2506,19 @@ def _generate_speech(
minlen = int(encoder_last_hidden_state.size(1) * minlenratio / model.config.reduction_factor)
# Start the output sequence with a mel spectrum that is all zeros.
output_sequence = encoder_last_hidden_state.new_zeros(1, 1, model.config.num_mel_bins)
output_sequence = encoder_last_hidden_state.new_zeros(bsz, 1, model.config.num_mel_bins)
spectrogram = []
cross_attentions = []
past_key_values = None
idx = 0
result_spectrogram = {}
while True:
idx += 1
# Run the decoder prenet on the entire output sequence.
decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings)
# Run the decoder layers on the last element of the prenet output.
decoder_out = model.speecht5.decoder.wrapped_decoder(
hidden_states=decoder_hidden_states[:, -1:],
@@ -2523,36 +2534,73 @@ def _generate_speech(
if output_cross_attentions:
cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0))
last_decoder_output = decoder_out.last_hidden_state[0, -1]
last_decoder_output = decoder_out.last_hidden_state.squeeze(1)
past_key_values = decoder_out.past_key_values
# Predict the new mel spectrum for this step in the sequence.
spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output)
spectrum = spectrum.view(model.config.reduction_factor, model.config.num_mel_bins)
spectrum = spectrum.view(bsz, model.config.reduction_factor, model.config.num_mel_bins)
spectrogram.append(spectrum)
# Extend the output sequence with the new mel spectrum.
output_sequence = torch.cat((output_sequence, spectrum[-1].view(1, 1, model.config.num_mel_bins)), dim=1)
new_spectrogram = spectrum[:, -1, :].view(bsz, 1, model.config.num_mel_bins)
output_sequence = torch.cat((output_sequence, new_spectrogram), dim=1)
# Predict the probability that this is the stop token.
prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output))
# Finished when stop token or maximum length is reached.
if idx >= minlen and (int(sum(prob >= threshold)) > 0 or idx >= maxlen):
spectrogram = torch.cat(spectrogram, dim=0).unsqueeze(0)
spectrogram = model.speech_decoder_postnet.postnet(spectrogram)
spectrogram = spectrogram.squeeze(0)
break
if vocoder is not None:
outputs = vocoder(spectrogram)
if idx < minlen:
continue
else:
# If the generation loop has not yet reached the maximum length, check which batch items have met
# the prob threshold. Otherwise, assume all have met it and fill in the remaining spectrograms for the batch.
if idx < maxlen:
meet_thresholds = torch.sum(prob, dim=-1) >= threshold
meet_indexes = torch.where(meet_thresholds)[0].tolist()
else:
meet_indexes = range(len(prob))
meet_indexes = [i for i in meet_indexes if i not in result_spectrogram]
if len(meet_indexes) > 0:
spectrograms = torch.stack(spectrogram)
spectrograms = spectrograms.transpose(0, 1).flatten(1, 2)
spectrograms = model.speech_decoder_postnet.postnet(spectrograms)
for meet_index in meet_indexes:
result_spectrogram[meet_index] = spectrograms[meet_index]
if len(result_spectrogram) >= bsz:
break
spectrograms = [result_spectrogram[i] for i in range(len(result_spectrogram))]
if not return_output_lengths:
spectrogram = spectrograms[0] if bsz == 1 else torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
if vocoder is not None:
outputs = vocoder(spectrogram)
else:
outputs = spectrogram
if output_cross_attentions:
cross_attentions = torch.cat(cross_attentions, dim=2)
if bsz > 1:
cross_attentions = cross_attentions.view(
bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:]
)
outputs = (outputs, cross_attentions)
else:
outputs = spectrogram
if output_cross_attentions:
cross_attentions = torch.cat(cross_attentions, dim=2)
outputs = (outputs, cross_attentions)
# batched return values should also include the spectrogram/waveform lengths
spectrogram_lengths = []
for i in range(bsz):
spectrogram_lengths.append(spectrograms[i].size(0))
if vocoder is None:
spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
outputs = (spectrograms, spectrogram_lengths)
else:
waveforms = []
spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
waveforms = vocoder(spectrograms)
waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]
outputs = (waveforms, waveform_lengths)
if output_cross_attentions:
cross_attentions = torch.cat(cross_attentions, dim=2)
cross_attentions = cross_attentions.view(
bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:]
)
outputs = (*outputs, cross_attentions)
return outputs
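The `waveform_lengths` comprehension above recovers per-item sample counts from the padded vocoder output, relying on the vocoder mapping each spectrogram frame to a fixed number of samples. A worked example with illustrative numbers:

# 3 items, longest spectrogram 250 frames, padded vocoder output 64000 samples,
# i.e. 64000 / 250 = 256 samples per spectrogram frame.
spectrogram_lengths = [250, 180, 220]
num_samples = 64000
waveform_lengths = [int(num_samples / max(spectrogram_lengths)) * n for n in spectrogram_lengths]
print(waveform_lengths)  # [64000, 46080, 56320] == [256 * n for n in spectrogram_lengths]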
@@ -2612,7 +2660,7 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
) -> Union[Tuple, Seq2SeqSpectrogramOutput]:
r"""
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. The `batch_size` should be 1 currently.
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and
[`~PreTrainedTokenizer.__call__`] for details.
@@ -2719,12 +2767,14 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
def generate(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
speaker_embeddings: Optional[torch.FloatTensor] = None,
threshold: float = 0.5,
minlenratio: float = 0.0,
maxlenratio: float = 20.0,
vocoder: Optional[nn.Module] = None,
output_cross_attentions: bool = False,
return_output_lengths: bool = False,
**kwargs,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
r"""
@@ -2733,12 +2783,15 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. The `batch_size` should be 1 currently.
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and
[`~PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Attention mask from the tokenizer, required for batched inference to tell the model which padded
tokens in `input_ids` to ignore.
speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
Tensor containing the speaker embeddings.
threshold (`float`, *optional*, defaults to 0.5):
@@ -2752,26 +2805,44 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
spectrogram.
output_cross_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of the decoder's cross-attention layers.
return_output_lengths (`bool`, *optional*, defaults to `False`):
Whether or not to return the concrete spectrogram/waveform lengths.
Returns:
`tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
- **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
`(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
- **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
`(num_frames,)` -- The predicted speech waveform.
- **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) `torch.FloatTensor`
of shape `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length,
input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
- when `return_output_lengths` is False
- **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
`(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
- **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
`(num_frames,)` -- The predicted speech waveform.
- **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
`torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads,
output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
- when `return_output_lengths` is True
- **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
`(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that
are padded to the maximum length.
- **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `List[Int]` -- A list of
all the concrete lengths for each spectrogram.
- **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
`(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length.
- **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `List[Int]` -- A list of all
the concrete lengths for each waveform.
- **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
"""
return _generate_speech(
self,
input_ids,
speaker_embeddings,
attention_mask,
threshold,
minlenratio,
maxlenratio,
vocoder,
output_cross_attentions,
return_output_lengths,
)
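Put concretely, the Returns contract means callers unpack differently depending on `return_output_lengths`. A sketch reusing the model, inputs, and vocoder from the first example:

# Legacy path (return_output_lengths=False, the default): one padded tensor, as before.
spectrogram = model.generate(input_ids, speaker_embeddings=speaker_embeddings)

# Batched path: a (padded_outputs, lengths) pair, with or without a vocoder.
spectrograms, spectrogram_lengths = model.generate(
    input_ids,
    attention_mask=attention_mask,
    speaker_embeddings=speaker_embeddings,
    return_output_lengths=True,
)
waveforms, waveform_lengths = model.generate(
    input_ids,
    attention_mask=attention_mask,
    speaker_embeddings=speaker_embeddings,
    vocoder=vocoder,
    return_output_lengths=True,
)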
@torch.no_grad()
@@ -2779,11 +2850,13 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
self,
input_ids: torch.LongTensor,
speaker_embeddings: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
threshold: float = 0.5,
minlenratio: float = 0.0,
maxlenratio: float = 20.0,
vocoder: Optional[nn.Module] = None,
output_cross_attentions: bool = False,
return_output_lengths: bool = False,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
r"""
Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a
@@ -2791,7 +2864,7 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. The `batch_size` should be 1 currently.
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and
[`~PreTrainedTokenizer.__call__`] for details.
@@ -2799,6 +2872,14 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
[What are input IDs?](../glossary#input-ids)
speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
Tensor containing the speaker embeddings.
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
`[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
threshold (`float`, *optional*, defaults to 0.5):
The generated sequence ends when the predicted stop token probability exceeds this value.
minlenratio (`float`, *optional*, defaults to 0.0):
@@ -2810,26 +2891,44 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
spectrogram.
output_cross_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of the decoder's cross-attention layers.
return_output_lengths (`bool`, *optional*, defaults to `False`):
Whether or not to return the concrete spectrogram/waveform lengths.
Returns:
`tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
- **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
`(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
- **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
`(num_frames,)` -- The predicted speech waveform.
- **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) `torch.FloatTensor`
of shape `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length,
input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
- when `return_output_lengths` is False
- **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
`(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
- **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
`(num_frames,)` -- The predicted speech waveform.
- **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
`torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads,
output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
- when `return_output_lengths` is True
- **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
`(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that
are padded to the maximum length.
- **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `List[Int]` -- A list of
all the concrete lengths for each spectrogram.
- **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
`(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length.
- **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `List[Int]` -- A list of all
the concrete lengths for each waveform.
- **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
"""
return _generate_speech(
self,
input_ids,
speaker_embeddings,
attention_mask,
threshold,
minlenratio,
maxlenratio,
vocoder,
output_cross_attentions,
return_output_lengths,
)
@@ -2988,11 +3087,13 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
self,
input_values: torch.FloatTensor,
speaker_embeddings: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
threshold: float = 0.5,
minlenratio: float = 0.0,
maxlenratio: float = 20.0,
vocoder: Optional[nn.Module] = None,
output_cross_attentions: bool = False,
return_output_lengths: bool = False,
) -> torch.FloatTensor:
r"""
Converts a raw speech waveform into a sequence of mel spectrograms, which are subsequently turned back into a
@@ -3000,7 +3101,7 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. The `batch_size` should be 1 currently.
Float values of input raw speech waveform.
Values can be obtained by loading a *.flac* or *.wav* audio file into an array of type `List[float]` or
a `numpy.ndarray`, *e.g.* via the soundfile library (*pip install soundfile*). To prepare the array
@@ -3008,6 +3109,14 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details.
speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
Tensor containing the speaker embeddings.
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
`[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
threshold (`float`, *optional*, defaults to 0.5):
The generated sequence ends when the predicted stop token probability exceeds this value.
minlenratio (`float`, *optional*, defaults to 0.0):
@@ -3019,16 +3128,32 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
spectrogram.
output_cross_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of the decoder's cross-attention layers.
return_output_lengths (`bool`, *optional*, defaults to `False`):
Whether or not to return the concrete spectrogram/waveform lengths.
Returns:
`tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
- **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
`(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
- **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
`(num_frames,)` -- The predicted speech waveform.
- **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) `torch.FloatTensor`
of shape `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length,
input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
- when `return_output_lengths` is False
- **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
`(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
- **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
`(num_frames,)` -- The predicted speech waveform.
- **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
`torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads,
output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
- when `return_output_lengths` is True
- **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
`(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that
are padded to the maximum length.
- **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `List[Int]` -- A list of
all the concrete lengths for each spectrogram.
- **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
`(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length.
- **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `List[Int]` -- A list of all
the concrete lengths for each waveform.
- **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
"""
if speaker_embeddings is None:
speaker_embeddings = torch.zeros((1, 512), device=input_values.device)
@@ -3037,11 +3162,13 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
self,
input_values,
speaker_embeddings,
attention_mask,
threshold,
minlenratio,
maxlenratio,
vocoder,
output_cross_attentions,
return_output_lengths,
)
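`SpeechT5ForSpeechToSpeech.generate_speech` gains the same batching surface. A sketch under assumptions: the `microsoft/speecht5_vc` checkpoint, `audio` as a placeholder list of 16 kHz waveforms, and an `attention_mask` whose presence depends on the feature extractor config:

import torch
from transformers import SpeechT5ForSpeechToSpeech, SpeechT5HifiGan, SpeechT5Processor

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_vc")
model = SpeechT5ForSpeechToSpeech.from_pretrained("microsoft/speecht5_vc")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# `audio` stands in for a list of 16 kHz numpy waveforms (e.g. loaded via soundfile).
inputs = processor(audio=audio, sampling_rate=16000, padding=True, return_tensors="pt")

waveforms, waveform_lengths = model.generate_speech(
    inputs["input_values"],
    speaker_embeddings=torch.zeros((1, 512)),  # zeros are the documented fallback
    attention_mask=inputs.get("attention_mask"),
    vocoder=vocoder,
    return_output_lengths=True,
)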

tests/models/speecht5/test_modeling_speecht5.py

@@ -1026,14 +1026,21 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
@require_torch
@require_sentencepiece
@require_tokenizers
@slow
class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
@cached_property
def default_model(self):
return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
@cached_property
def default_processor(self):
return SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
@cached_property
def default_vocoder(self):
return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
def test_generation(self):
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
model = self.default_model
model.to(torch_device)
processor = self.default_processor
@@ -1045,7 +1052,7 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)
generated_speech = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings)
self.assertEqual(generated_speech.shape, (228, model.config.num_mel_bins))
self.assertEqual(generated_speech.shape, (230, model.config.num_mel_bins))
set_seed(555) # make deterministic
@@ -1053,7 +1060,76 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
generated_speech_with_generate = model.generate(
input_ids, attention_mask=None, speaker_embeddings=speaker_embeddings
)
self.assertEqual(generated_speech_with_generate.shape, (228, model.config.num_mel_bins))
self.assertEqual(generated_speech_with_generate.shape, (230, model.config.num_mel_bins))
def test_batch_generation(self):
model = self.default_model
model.to(torch_device)
processor = self.default_processor
vocoder = self.default_vocoder
set_seed(555) # make deterministic
input_text = [
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
"nor is mister quilter's manner less interesting than his matter",
"he tells us that at this festive season of the year with christmas and rosebeaf looming before us",
]
inputs = processor(text=input_text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
speaker_embeddings = torch.zeros((1, 512), device=torch_device)
spectrograms, spectrogram_lengths = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
attention_mask=inputs["attention_mask"],
return_output_lengths=True,
)
self.assertEqual(spectrograms.shape, (3, 262, model.config.num_mel_bins))
waveforms = vocoder(spectrograms)
waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]
# Check waveform results are the same with or without using vocoder
set_seed(555)
waveforms_with_vocoder, waveform_lengths_with_vocoder = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
attention_mask=inputs["attention_mask"],
vocoder=vocoder,
return_output_lengths=True,
)
self.assertTrue(torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8))
self.assertEqual(waveform_lengths, waveform_lengths_with_vocoder)
# Check waveform results are the same with return_output_lengths=True/False
set_seed(555)
waveforms_with_vocoder_no_lengths = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
attention_mask=inputs["attention_mask"],
vocoder=vocoder,
return_output_lengths=False,
)
self.assertTrue(torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8))
# Check that results with batching are consistent with results without batching
for i, text in enumerate(input_text):
set_seed(555) # make deterministic
inputs = processor(text=text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
spectrogram = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
)
self.assertEqual(spectrogram.shape, spectrograms[i][: spectrogram_lengths[i]].shape)
self.assertTrue(torch.allclose(spectrogram, spectrograms[i][: spectrogram_lengths[i]], atol=5e-3))
waveform = vocoder(spectrogram)
self.assertEqual(waveform.shape, waveforms[i][: waveform_lengths[i]].shape)
# Check whether waveforms are the same with/without passing vocoder
set_seed(555)
waveform_with_vocoder = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
vocoder=vocoder,
)
self.assertTrue(torch.allclose(waveform, waveform_with_vocoder, atol=1e-8))
@require_torch