add generate method to SpeechT5ForTextToSpeech (#25233)

* add generate method to SpeechT5ForTextToSpeech

* update speecht5forTTS docstrings

* Remove defaults to None in generate docstrings

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Yoach Lacombe 2023-08-03 15:12:07 +02:00 committed by GitHub
parent 8455346c5c
commit 6d3f9c1e2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 65 additions and 2 deletions

View File

@ -71,7 +71,7 @@ This model was contributed by [Matthijs](https://huggingface.co/Matthijs). The o
[[autodoc]] SpeechT5ForTextToSpeech
- forward
- generate_speech
- generate
## SpeechT5ForSpeechToSpeech

View File

@ -2717,7 +2717,7 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
>>> set_seed(555) # make deterministic
>>> # generate speech
>>> speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
>>> speech = model.generate(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
>>> speech.shape
torch.Size([15872])
```
@ -2783,6 +2783,65 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
encoder_attentions=outputs.encoder_attentions,
)
@torch.no_grad()
def generate(
self,
input_ids: torch.LongTensor,
speaker_embeddings: Optional[torch.FloatTensor] = None,
threshold: float = 0.5,
minlenratio: float = 0.0,
maxlenratio: float = 20.0,
vocoder: Optional[nn.Module] = None,
output_cross_attentions: bool = False,
**kwargs,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
r"""
Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a
speech waveform using a vocoder.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. The `batch_size` should be 1 currently.
Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and
[`~PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
Tensor containing the speaker embeddings.
threshold (`float`, *optional*, defaults to 0.5):
The generated sequence ends when the predicted stop token probability exceeds this value.
minlenratio (`float`, *optional*, defaults to 0.0):
Used to calculate the minimum required length for the output sequence.
maxlenratio (`float`, *optional*, defaults to 20.0):
Used to calculate the maximum allowed length for the output sequence.
vocoder (`nn.Module`, *optional*):
The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel
spectrogram.
output_cross_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of the decoder's cross-attention layers.
Returns:
`tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
- **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
`(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
- **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
`(num_frames,)` -- The predicted speech waveform.
- **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) `torch.FloatTensor`
of shape `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length,
input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
"""
return _generate_speech(
self,
input_ids,
speaker_embeddings,
threshold,
minlenratio,
maxlenratio,
vocoder,
output_cross_attentions,
)
@torch.no_grad()
def generate_speech(
self,

View File

@ -1020,6 +1020,10 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
generated_speech = model.generate_speech(input_ids)
self.assertEqual(generated_speech.shape, (1820, model.config.num_mel_bins))
# test model.generate, same method than generate_speech but with additional kwargs to absorb kwargs such as attention_mask
generated_speech_with_generate = model.generate(input_ids, attention_mask=None)
self.assertEqual(generated_speech_with_generate.shape, (1820, model.config.num_mel_bins))
@require_torch
class SpeechT5ForSpeechToSpeechTester: