mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
add generate method to SpeechT5ForTextToSpeech (#25233)
* add generate method to SpeechT5ForTextToSpeech * update speecht5forTTS docstrings * Remove defaults to None in generate docstrings Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
8455346c5c
commit
6d3f9c1e2e
@ -71,7 +71,7 @@ This model was contributed by [Matthijs](https://huggingface.co/Matthijs). The o
|
||||
|
||||
[[autodoc]] SpeechT5ForTextToSpeech
|
||||
- forward
|
||||
- generate_speech
|
||||
- generate
|
||||
|
||||
## SpeechT5ForSpeechToSpeech
|
||||
|
||||
|
@ -2717,7 +2717,7 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
|
||||
>>> set_seed(555) # make deterministic
|
||||
|
||||
>>> # generate speech
|
||||
>>> speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
|
||||
>>> speech = model.generate(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
|
||||
>>> speech.shape
|
||||
torch.Size([15872])
|
||||
```
|
||||
@ -2783,6 +2783,65 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
|
||||
encoder_attentions=outputs.encoder_attentions,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
speaker_embeddings: Optional[torch.FloatTensor] = None,
|
||||
threshold: float = 0.5,
|
||||
minlenratio: float = 0.0,
|
||||
maxlenratio: float = 20.0,
|
||||
vocoder: Optional[nn.Module] = None,
|
||||
output_cross_attentions: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
|
||||
r"""
|
||||
Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a
|
||||
speech waveform using a vocoder.
|
||||
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Currently, `batch_size` should be 1.
|
||||
|
||||
Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and
|
||||
[`~PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
|
||||
Tensor containing the speaker embeddings.
|
||||
threshold (`float`, *optional*, defaults to 0.5):
|
||||
The generated sequence ends when the predicted stop token probability exceeds this value.
|
||||
minlenratio (`float`, *optional*, defaults to 0.0):
|
||||
Used to calculate the minimum required length for the output sequence.
|
||||
maxlenratio (`float`, *optional*, defaults to 20.0):
|
||||
Used to calculate the maximum allowed length for the output sequence.
|
||||
vocoder (`nn.Module`, *optional*):
|
||||
The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel
|
||||
spectrogram.
|
||||
output_cross_attentions (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return the attentions tensors of the decoder's cross-attention layers.
|
||||
|
||||
Returns:
|
||||
`tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
|
||||
- **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
|
||||
`(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
|
||||
- **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
|
||||
`(num_frames,)` -- The predicted speech waveform.
|
||||
- **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) `torch.FloatTensor`
|
||||
of shape `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length,
|
||||
input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
|
||||
"""
|
||||
return _generate_speech(
|
||||
self,
|
||||
input_ids,
|
||||
speaker_embeddings,
|
||||
threshold,
|
||||
minlenratio,
|
||||
maxlenratio,
|
||||
vocoder,
|
||||
output_cross_attentions,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_speech(
|
||||
self,
|
||||
|
@ -1020,6 +1020,10 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
|
||||
generated_speech = model.generate_speech(input_ids)
|
||||
self.assertEqual(generated_speech.shape, (1820, model.config.num_mel_bins))
|
||||
|
||||
# test model.generate: the same method as generate_speech, but it also accepts extra kwargs (such as attention_mask) and ignores them
|
||||
generated_speech_with_generate = model.generate(input_ids, attention_mask=None)
|
||||
self.assertEqual(generated_speech_with_generate.shape, (1820, model.config.num_mel_bins))
|
||||
|
||||
|
||||
@require_torch
|
||||
class SpeechT5ForSpeechToSpeechTester:
|
||||
|
Loading…
Reference in New Issue
Block a user