Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 19:21:31 +06:00)
Add speecht5 batch generation and fix wrong attention mask when padding (#25943)
* fix speecht5 wrong attention mask when padding
* enable batch generation and add parameter attention_mask
* fix doc
* fix format
* batch postnet inputs, return batched lengths, and keep consistency with the old API
* fix format
* fix format
* fix the format
* fix doc-builder error
* add test, cross attention and docstring
* optimize code based on reviews
* docbuild
* refine
* not skip slow test
* add consistent dropout for batching
* loosen atol
* add another test regarding the consistency of the vocoder
* fix format
* refactor
* add return_concrete_lengths as parameter for consistency w/wo batching
* fix review issues
* fix cross_attention issue
This commit is contained in:
parent ee4fb326c7
commit 4309abedbc
@@ -674,6 +674,11 @@ class SpeechT5SpeechDecoderPrenet(nn.Module):
         self.speaker_embeds_layer = nn.Linear(config.speaker_embedding_dim + config.hidden_size, config.hidden_size)
 
+    def _consistent_dropout(self, inputs_embeds, p):
+        mask = torch.bernoulli(inputs_embeds[0], p=p)
+        all_masks = mask.unsqueeze(0).repeat(inputs_embeds.size(0), 1, 1)
+        return torch.where(all_masks == 1, inputs_embeds, 0) * 1 / (1 - p)
+
     def forward(
         self,
         input_values: torch.Tensor,
@@ -684,9 +689,7 @@ class SpeechT5SpeechDecoderPrenet(nn.Module):
         inputs_embeds = input_values
         for layer in self.layers:
             inputs_embeds = nn.functional.relu(layer(inputs_embeds))
-            inputs_embeds = nn.functional.dropout(
-                inputs_embeds, self.config.speech_decoder_prenet_dropout, training=True
-            )
+            inputs_embeds = self._consistent_dropout(inputs_embeds, self.config.speech_decoder_prenet_dropout)
 
         inputs_embeds = self.final_layer(inputs_embeds)
         inputs_embeds = self.encode_positions(inputs_embeds)
@@ -695,6 +698,7 @@ class SpeechT5SpeechDecoderPrenet(nn.Module):
             speaker_embeddings = nn.functional.normalize(speaker_embeddings)
             speaker_embeddings = speaker_embeddings.unsqueeze(1)
             speaker_embeddings = speaker_embeddings.expand(-1, inputs_embeds.size(1), -1)
+            speaker_embeddings = speaker_embeddings.repeat(inputs_embeds.size(0), 1, 1)
             inputs_embeds = torch.cat([inputs_embeds, speaker_embeddings], dim=-1)
             inputs_embeds = nn.functional.relu(self.speaker_embeds_layer(inputs_embeds))
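Context for the `_consistent_dropout` change above: SpeechT5's decoder prenet deliberately keeps dropout active at inference time. With plain `nn.functional.dropout`, every batch item would draw an independent mask, so a sentence's spectrogram would depend on its batch neighbors and batched generation could never match unbatched generation. The new method samples one mask and repeats it across the batch. A minimal standalone sketch of the same idea (the shapes and `p` value are illustrative, not taken from the model config):

import torch

def consistent_dropout(x: torch.Tensor, p: float) -> torch.Tensor:
    # Sample a single Bernoulli mask once, then share it across the whole batch.
    mask = torch.bernoulli(torch.full_like(x[0], p))
    all_masks = mask.unsqueeze(0).repeat(x.size(0), 1, 1)
    return torch.where(all_masks == 1, x, torch.zeros_like(x)) / (1 - p)

x = torch.randn(3, 5, 8)  # (batch, seq_len, hidden)
out = consistent_dropout(x, p=0.2)
# Every batch item is zeroed at exactly the same positions:
assert torch.equal(out[0] == 0, out[1] == 0)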
@@ -2461,11 +2465,13 @@ def _generate_speech(
     model: SpeechT5PreTrainedModel,
     input_values: torch.FloatTensor,
     speaker_embeddings: Optional[torch.FloatTensor] = None,
+    attention_mask: Optional[torch.LongTensor] = None,
     threshold: float = 0.5,
     minlenratio: float = 0.0,
     maxlenratio: float = 20.0,
     vocoder: Optional[nn.Module] = None,
     output_cross_attentions: bool = False,
+    return_output_lengths: bool = False,
 ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
     if speaker_embeddings is None:
         raise ValueError(
@@ -2475,7 +2481,12 @@ def _generate_speech(
             """
         )
 
-    encoder_attention_mask = torch.ones_like(input_values)
+    if attention_mask is None:
+        encoder_attention_mask = 1 - (input_values == model.config.pad_token_id).int()
+    else:
+        encoder_attention_mask = attention_mask
+
+    bsz = input_values.size(0)
 
     encoder_out = model.speecht5.encoder(
         input_values=input_values,
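The replaced line above used an all-ones mask, which attends to padding; the new code derives the mask from `pad_token_id` when the caller does not pass one. A small sketch of that derivation (the `pad_token_id` value and the ids are illustrative):

import torch

pad_token_id = 1  # illustrative; the real value comes from model.config.pad_token_id
input_ids = torch.tensor([[17, 23, 35, 1, 1],
                          [17, 42, 1, 1, 1]])
encoder_attention_mask = 1 - (input_ids == pad_token_id).int()
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 0, 0, 0]], dtype=torch.int32)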
@@ -2495,19 +2506,19 @@
     minlen = int(encoder_last_hidden_state.size(1) * minlenratio / model.config.reduction_factor)
 
     # Start the output sequence with a mel spectrum that is all zeros.
-    output_sequence = encoder_last_hidden_state.new_zeros(1, 1, model.config.num_mel_bins)
+    output_sequence = encoder_last_hidden_state.new_zeros(bsz, 1, model.config.num_mel_bins)
 
     spectrogram = []
     cross_attentions = []
     past_key_values = None
     idx = 0
+    result_spectrogram = {}
 
     while True:
         idx += 1
 
         # Run the decoder prenet on the entire output sequence.
         decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings)
 
         # Run the decoder layers on the last element of the prenet output.
         decoder_out = model.speecht5.decoder.wrapped_decoder(
             hidden_states=decoder_hidden_states[:, -1:],
@@ -2523,36 +2534,73 @@
         if output_cross_attentions:
             cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0))
 
-        last_decoder_output = decoder_out.last_hidden_state[0, -1]
+        last_decoder_output = decoder_out.last_hidden_state.squeeze(1)
         past_key_values = decoder_out.past_key_values
 
         # Predict the new mel spectrum for this step in the sequence.
         spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output)
-        spectrum = spectrum.view(model.config.reduction_factor, model.config.num_mel_bins)
+        spectrum = spectrum.view(bsz, model.config.reduction_factor, model.config.num_mel_bins)
         spectrogram.append(spectrum)
 
         # Extend the output sequence with the new mel spectrum.
-        output_sequence = torch.cat((output_sequence, spectrum[-1].view(1, 1, model.config.num_mel_bins)), dim=1)
-
+        new_spectrogram = spectrum[:, -1, :].view(bsz, 1, model.config.num_mel_bins)
+        output_sequence = torch.cat((output_sequence, new_spectrogram), dim=1)
         # Predict the probability that this is the stop token.
         prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output))
 
-        # Finished when stop token or maximum length is reached.
-        if idx >= minlen and (int(sum(prob >= threshold)) > 0 or idx >= maxlen):
-            spectrogram = torch.cat(spectrogram, dim=0).unsqueeze(0)
-            spectrogram = model.speech_decoder_postnet.postnet(spectrogram)
-            spectrogram = spectrogram.squeeze(0)
-            break
-
-    if vocoder is not None:
-        outputs = vocoder(spectrogram)
-    else:
-        outputs = spectrogram
-
-    if output_cross_attentions:
-        cross_attentions = torch.cat(cross_attentions, dim=2)
-        outputs = (outputs, cross_attentions)
-
+        if idx < minlen:
+            continue
+        else:
+            # If the generation loop is less than maximum length time, check the ones in the batch that have met
+            # the prob threshold. Otherwise, assume all have met thresholds and fill other spectrograms for the batch.
+            if idx < maxlen:
+                meet_thresholds = torch.sum(prob, dim=-1) >= threshold
+                meet_indexes = torch.where(meet_thresholds)[0].tolist()
+            else:
+                meet_indexes = range(len(prob))
+            meet_indexes = [i for i in meet_indexes if i not in result_spectrogram]
+            if len(meet_indexes) > 0:
+                spectrograms = torch.stack(spectrogram)
+                spectrograms = spectrograms.transpose(0, 1).flatten(1, 2)
+                spectrograms = model.speech_decoder_postnet.postnet(spectrograms)
+                for meet_index in meet_indexes:
+                    result_spectrogram[meet_index] = spectrograms[meet_index]
+            if len(result_spectrogram) >= bsz:
+                break
+    spectrograms = [result_spectrogram[i] for i in range(len(result_spectrogram))]
+    if not return_output_lengths:
+        spectrogram = spectrograms[0] if bsz == 1 else torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
+        if vocoder is not None:
+            outputs = vocoder(spectrogram)
+        else:
+            outputs = spectrogram
+        if output_cross_attentions:
+            cross_attentions = torch.cat(cross_attentions, dim=2)
+            if bsz > 1:
+                cross_attentions = cross_attentions.view(
+                    bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:]
+                )
+            outputs = (outputs, cross_attentions)
+    else:
+        # batched return values should also include the spectrogram/waveform lengths
+        spectrogram_lengths = []
+        for i in range(bsz):
+            spectrogram_lengths.append(spectrograms[i].size(0))
+        if vocoder is None:
+            spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
+            outputs = (spectrograms, spectrogram_lengths)
+        else:
+            waveforms = []
+            spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
+            waveforms = vocoder(spectrograms)
+            waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]
+            outputs = (waveforms, waveform_lengths)
+        if output_cross_attentions:
+            cross_attentions = torch.cat(cross_attentions, dim=2)
+            cross_attentions = cross_attentions.view(
+                bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:]
+            )
+            outputs = (*outputs, cross_attentions)
     return outputs
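Two details in the batched loop above are worth spelling out. First, `result_spectrogram` records each batch item the first time it crosses the stop threshold, so the loop only ends once every index has an entry. Second, the vocoder runs once on the padded batch, so per-item waveform lengths are recovered by scaling each spectrogram length by the vocoder's fixed frames-to-samples ratio. A sketch of that length arithmetic with illustrative numbers (the 256x upsampling factor is an assumption about the vocoder, not taken from the diff):

spectrogram_lengths = [230, 180, 262]        # true frame counts per batch item
padded_frames = max(spectrogram_lengths)     # the batch is padded to 262 frames
total_samples = padded_frames * 256          # assumed HiFi-GAN upsampling factor
factor = int(total_samples / padded_frames)  # 256, as in the diff's formula
waveform_lengths = [factor * n for n in spectrogram_lengths]
# [58880, 46080, 67072] -> trim waveforms[i] to waveform_lengths[i]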
@@ -2612,7 +2660,7 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
     ) -> Union[Tuple, Seq2SeqSpectrogramOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-            Indices of input sequence tokens in the vocabulary. The `batch_size` should be 1 currently.
+            Indices of input sequence tokens in the vocabulary.
 
             Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and
             [`~PreTrainedTokenizer.__call__`] for details.
@@ -2719,12 +2767,14 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
     def generate(
         self,
         input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.LongTensor] = None,
         speaker_embeddings: Optional[torch.FloatTensor] = None,
         threshold: float = 0.5,
         minlenratio: float = 0.0,
         maxlenratio: float = 20.0,
         vocoder: Optional[nn.Module] = None,
         output_cross_attentions: bool = False,
+        return_output_lengths: bool = False,
         **kwargs,
     ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
         r"""
@@ -2733,12 +2783,15 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
 
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                Indices of input sequence tokens in the vocabulary. The `batch_size` should be 1 currently.
+                Indices of input sequence tokens in the vocabulary.
 
                 Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and
                 [`~PreTrainedTokenizer.__call__`] for details.
 
                 [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Attention mask from the tokenizer, required for batched inference to signal to the model where to
+                ignore padded tokens from the input_ids.
             speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
                 Tensor containing the speaker embeddings.
             threshold (`float`, *optional*, defaults to 0.5):
@@ -2752,26 +2805,44 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
                 spectrogram.
             output_cross_attentions (`bool`, *optional*, defaults to `False`):
                 Whether or not to return the attentions tensors of the decoder's cross-attention layers.
+            return_output_lengths (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the concrete spectrogram/waveform lengths.
 
         Returns:
             `tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
-            - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
-              `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
-            - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
-              `(num_frames,)` -- The predicted speech waveform.
-            - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) `torch.FloatTensor`
-              of shape `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length,
-              input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
+            - when `return_output_lengths` is False
+                - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
+                `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
+                - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
+                `(num_frames,)` -- The predicted speech waveform.
+                - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
+                `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads,
+                output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
+            - when `return_output_lengths` is True
+                - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
+                `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that
+                are padded to the maximum length.
+                - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `List[Int]` -- A list of
+                all the concrete lengths for each spectrogram.
+                - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
+                `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length.
+                - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `List[Int]` -- A list of all
+                the concrete lengths for each waveform.
+                - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
+                `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
+                output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
         """
         return _generate_speech(
             self,
             input_ids,
             speaker_embeddings,
+            attention_mask,
             threshold,
             minlenratio,
             maxlenratio,
             vocoder,
             output_cross_attentions,
+            return_output_lengths,
         )
 
     @torch.no_grad()
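To make the new batched API concrete, a usage sketch assembled from the signature and docstring above (checkpoint names as published on the Hub; the zero tensor is a placeholder for a real x-vector speaker embedding):

import torch
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

inputs = processor(text=["hello world", "a somewhat longer second sentence"], padding=True, return_tensors="pt")
speaker_embeddings = torch.zeros((1, 512))  # placeholder speaker embedding

waveforms, waveform_lengths = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    speaker_embeddings=speaker_embeddings,
    vocoder=vocoder,
    return_output_lengths=True,
)
# waveforms is padded to the longest item; trim each one to its true length.
speech = [waveforms[i, : waveform_lengths[i]] for i in range(len(waveform_lengths))]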
@@ -2779,11 +2850,13 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
         self,
         input_ids: torch.LongTensor,
         speaker_embeddings: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
         threshold: float = 0.5,
         minlenratio: float = 0.0,
         maxlenratio: float = 20.0,
         vocoder: Optional[nn.Module] = None,
         output_cross_attentions: bool = False,
+        return_output_lengths: bool = False,
     ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
         r"""
         Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a
@@ -2791,7 +2864,7 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
 
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                Indices of input sequence tokens in the vocabulary. The `batch_size` should be 1 currently.
+                Indices of input sequence tokens in the vocabulary.
 
                 Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and
                 [`~PreTrainedTokenizer.__call__`] for details.
@@ -2799,6 +2872,14 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
 
                 [What are input IDs?](../glossary#input-ids)
             speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
                 Tensor containing the speaker embeddings.
+            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
+                `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
             threshold (`float`, *optional*, defaults to 0.5):
                 The generated sequence ends when the predicted stop token probability exceeds this value.
             minlenratio (`float`, *optional*, defaults to 0.0):
@@ -2810,26 +2891,44 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
                 spectrogram.
             output_cross_attentions (`bool`, *optional*, defaults to `False`):
                 Whether or not to return the attentions tensors of the decoder's cross-attention layers.
+            return_output_lengths (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the concrete spectrogram/waveform lengths.
 
         Returns:
             `tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
-            - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
-              `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
-            - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
-              `(num_frames,)` -- The predicted speech waveform.
-            - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) `torch.FloatTensor`
-              of shape `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length,
-              input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
+            - when `return_output_lengths` is False
+                - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
+                `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
+                - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
+                `(num_frames,)` -- The predicted speech waveform.
+                - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
+                `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads,
+                output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
+            - when `return_output_lengths` is True
+                - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
+                `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that
+                are padded to the maximum length.
+                - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `List[Int]` -- A list of
+                all the concrete lengths for each spectrogram.
+                - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
+                `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length.
+                - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `List[Int]` -- A list of all
+                the concrete lengths for each waveform.
+                - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
+                `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
+                output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
         """
         return _generate_speech(
             self,
             input_ids,
             speaker_embeddings,
+            attention_mask,
             threshold,
             minlenratio,
             maxlenratio,
             vocoder,
             output_cross_attentions,
+            return_output_lengths,
         )
@@ -2988,11 +3087,13 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
         self,
         input_values: torch.FloatTensor,
         speaker_embeddings: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
         threshold: float = 0.5,
         minlenratio: float = 0.0,
         maxlenratio: float = 20.0,
         vocoder: Optional[nn.Module] = None,
         output_cross_attentions: bool = False,
+        return_output_lengths: bool = False,
     ) -> torch.FloatTensor:
         r"""
         Converts a raw speech waveform into a sequence of mel spectrograms, which are subsequently turned back into a
@@ -3000,7 +3101,7 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
 
         Args:
             input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
-                Float values of input raw speech waveform. The `batch_size` should be 1 currently.
+                Float values of input raw speech waveform.
 
                 Values can be obtained by loading a *.flac* or *.wav* audio file into an array of type `List[float]` or
                 a `numpy.ndarray`, *e.g.* via the soundfile library (*pip install soundfile*). To prepare the array
@@ -3008,6 +3109,14 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
                 of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details.
             speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
                 Tensor containing the speaker embeddings.
+            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
+                `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
             threshold (`float`, *optional*, defaults to 0.5):
                 The generated sequence ends when the predicted stop token probability exceeds this value.
             minlenratio (`float`, *optional*, defaults to 0.0):
@@ -3019,16 +3128,32 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
                 spectrogram.
             output_cross_attentions (`bool`, *optional*, defaults to `False`):
                 Whether or not to return the attentions tensors of the decoder's cross-attention layers.
+            return_output_lengths (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the concrete spectrogram/waveform lengths.
 
         Returns:
             `tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
-            - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
-              `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
-            - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
-              `(num_frames,)` -- The predicted speech waveform.
-            - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) `torch.FloatTensor`
-              of shape `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length,
-              input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
+            - when `return_output_lengths` is False
+                - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
+                `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
+                - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
+                `(num_frames,)` -- The predicted speech waveform.
+                - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
+                `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads,
+                output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
+            - when `return_output_lengths` is True
+                - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
+                `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that
+                are padded to the maximum length.
+                - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `List[Int]` -- A list of
+                all the concrete lengths for each spectrogram.
+                - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
+                `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length.
+                - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `List[Int]` -- A list of all
+                the concrete lengths for each waveform.
+                - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
+                `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
+                output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
         """
         if speaker_embeddings is None:
             speaker_embeddings = torch.zeros((1, 512), device=input_values.device)
@@ -3037,11 +3162,13 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
             self,
             input_values,
             speaker_embeddings,
+            attention_mask,
             threshold,
             minlenratio,
             maxlenratio,
             vocoder,
             output_cross_attentions,
+            return_output_lengths,
         )
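The speech-to-speech model gains the same two parameters. A batched voice-conversion sketch under stated assumptions (the `microsoft/speecht5_vc` checkpoint name is the published one; the random arrays stand in for real 16 kHz recordings, and `return_attention_mask=True` is assumed to be honored by the feature extractor):

import numpy as np
import torch
from transformers import SpeechT5ForSpeechToSpeech, SpeechT5HifiGan, SpeechT5Processor

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_vc")
model = SpeechT5ForSpeechToSpeech.from_pretrained("microsoft/speecht5_vc")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# Noise stands in for two utterances of different lengths.
audio_batch = [np.random.randn(16000).astype(np.float32), np.random.randn(24000).astype(np.float32)]
inputs = processor(
    audio=audio_batch, sampling_rate=16000, padding=True, return_attention_mask=True, return_tensors="pt"
)
speaker_embeddings = torch.zeros((1, 512))  # generate_speech also defaults to zeros

waveforms, waveform_lengths = model.generate_speech(
    inputs["input_values"],
    speaker_embeddings=speaker_embeddings,
    attention_mask=inputs["attention_mask"],
    vocoder=vocoder,
    return_output_lengths=True,
)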
@@ -1026,14 +1026,21 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
 @require_torch
 @require_sentencepiece
 @require_tokenizers
 @slow
 class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
+    @cached_property
+    def default_model(self):
+        return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
+
+    @cached_property
+    def default_processor(self):
+        return SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
+
+    @cached_property
+    def default_vocoder(self):
+        return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
+
     def test_generation(self):
-        model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
+        model = self.default_model
         model.to(torch_device)
+        processor = self.default_processor
@@ -1045,7 +1052,7 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
         input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)
 
         generated_speech = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings)
-        self.assertEqual(generated_speech.shape, (228, model.config.num_mel_bins))
+        self.assertEqual(generated_speech.shape, (230, model.config.num_mel_bins))
 
         set_seed(555)  # make deterministic
 
@@ -1053,7 +1060,76 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
         generated_speech_with_generate = model.generate(
             input_ids, attention_mask=None, speaker_embeddings=speaker_embeddings
         )
-        self.assertEqual(generated_speech_with_generate.shape, (228, model.config.num_mel_bins))
+        self.assertEqual(generated_speech_with_generate.shape, (230, model.config.num_mel_bins))
+
+    def test_batch_generation(self):
+        model = self.default_model
+        model.to(torch_device)
+        processor = self.default_processor
+        vocoder = self.default_vocoder
+        set_seed(555)  # make deterministic
+
+        input_text = [
+            "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
+            "nor is mister quilter's manner less interesting than his matter",
+            "he tells us that at this festive season of the year with christmas and rosebeaf looming before us",
+        ]
+        inputs = processor(text=input_text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
+
+        speaker_embeddings = torch.zeros((1, 512), device=torch_device)
+        spectrograms, spectrogram_lengths = model.generate_speech(
+            input_ids=inputs["input_ids"],
+            speaker_embeddings=speaker_embeddings,
+            attention_mask=inputs["attention_mask"],
+            return_output_lengths=True,
+        )
+        self.assertEqual(spectrograms.shape, (3, 262, model.config.num_mel_bins))
+        waveforms = vocoder(spectrograms)
+        waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]
+
+        # Check that waveform results are the same with or without using the vocoder
+        set_seed(555)
+        waveforms_with_vocoder, waveform_lengths_with_vocoder = model.generate_speech(
+            input_ids=inputs["input_ids"],
+            speaker_embeddings=speaker_embeddings,
+            attention_mask=inputs["attention_mask"],
+            vocoder=vocoder,
+            return_output_lengths=True,
+        )
+        self.assertTrue(torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8))
+        self.assertEqual(waveform_lengths, waveform_lengths_with_vocoder)
+
+        # Check that waveform results are the same with return_output_lengths=True/False
+        set_seed(555)
+        waveforms_with_vocoder_no_lengths = model.generate_speech(
+            input_ids=inputs["input_ids"],
+            speaker_embeddings=speaker_embeddings,
+            attention_mask=inputs["attention_mask"],
+            vocoder=vocoder,
+            return_output_lengths=False,
+        )
+        self.assertTrue(torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8))
+
+        # Check that batched results are consistent with non-batched results
+        for i, text in enumerate(input_text):
+            set_seed(555)  # make deterministic
+            inputs = processor(text=text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
+            spectrogram = model.generate_speech(
+                input_ids=inputs["input_ids"],
+                speaker_embeddings=speaker_embeddings,
+            )
+            self.assertEqual(spectrogram.shape, spectrograms[i][: spectrogram_lengths[i]].shape)
+            self.assertTrue(torch.allclose(spectrogram, spectrograms[i][: spectrogram_lengths[i]], atol=5e-3))
+            waveform = vocoder(spectrogram)
+            self.assertEqual(waveform.shape, waveforms[i][: waveform_lengths[i]].shape)
+            # Check whether waveforms are the same with/without passing the vocoder
+            set_seed(555)
+            waveform_with_vocoder = model.generate_speech(
+                input_ids=inputs["input_ids"],
+                speaker_embeddings=speaker_embeddings,
+                vocoder=vocoder,
+            )
+            self.assertTrue(torch.allclose(waveform, waveform_with_vocoder, atol=1e-8))
 
 
 @require_torch
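A note on the repeated `set_seed(555)` calls in these tests: because the prenet keeps dropout active at inference, generation is stochastic, and two runs only compare equal when the RNG state is reset first. A sketch of the pattern, reusing `model`, `input_ids`, and `speaker_embeddings` from the tests above:

import torch
from transformers import set_seed

set_seed(555)
first = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings)
set_seed(555)
second = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings)
# Identical, because the dropout masks are resampled from the same RNG state.
assert torch.allclose(first, second)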