TTS fine-tuning for SpeechT5 (#21824)

* wrong argument name

* append eos_token_id

* all tokenizers need mask and ctc_blank tokens

* remove reduction factor from feature extractor

* add proper TTS loss

* did shifting the wrong way around

* mask out padded portions

* remove logits again (don't really need it)

* fix unit tests

* fixup

* pad also returns the decoder attention mask, since that's useful to have

* clean up feature extractor logic

* pad can handle TTS task too

* remove stop_labels from loss calculation

* simplify logic

* fixup

* do -100 masking properly

* small STFT optimization (calculate mel filterbanks only once)

* replace torchaudio fbanks with audio_utils

* remove torchaudio dependency

* simplify & speed up the STFT

* don't serialize window and mel filters

* output cross attentions when generating speech

* add guided attention loss

* fix failing test

* Update src/transformers/models/speecht5/feature_extraction_speecht5.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/speecht5/modeling_speecht5.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* change type annotation of attention_mask to LongTensor

* extract loss into class

* remove unused frame_signal_scale argument

* use config object in loss class

* fix type annotations in doc comments

* change optional to just bool

* implement missing tokenizer method

* add deprecation warning

* Update src/transformers/models/speecht5/feature_extraction_speecht5.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/speecht5/feature_extraction_speecht5.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add deprecation warning for stop_labels

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Commit ac2bc50a10 (parent dacd34568d) by Matthijs Hollemans, 2023-04-18 11:12:30 +02:00, committed by GitHub
10 changed files with 448 additions and 234 deletions

View File

@@ -161,13 +161,22 @@ class SpeechT5Config(PretrainedConfig):
         speech_decoder_postnet_dropout (`float`, *optional*, defaults to 0.5):
             The dropout probability for the speech decoder post-net layers.
         reduction_factor (`int`, *optional*, defaults to 2):
-            Spectrogram length reduction factor for the speech decoder post-net.
+            Spectrogram length reduction factor for the speech decoder inputs.
         max_speech_positions (`int`, *optional*, defaults to 4000):
             The maximum sequence length of speech features that this model might ever be used with.
         max_text_positions (`int`, *optional*, defaults to 450):
             The maximum sequence length of text features that this model might ever be used with.
         encoder_max_relative_position (`int`, *optional*, defaults to 160):
             Maximum distance for relative position embedding in the encoder.
+        use_guided_attention_loss (`bool`, *optional*, defaults to `True`):
+            Whether to apply guided attention loss while training the TTS model.
+        guided_attention_loss_num_heads (`int`, *optional*, defaults to 2):
+            Number of attention heads the guided attention loss will be applied to. Use -1 to apply this loss to all
+            attention heads.
+        guided_attention_loss_sigma (`float`, *optional*, defaults to 0.4):
+            Standard deviation for guided attention loss.
+        guided_attention_loss_scale (`float`, *optional*, defaults to 10.0):
+            Scaling coefficient for guided attention loss (also known as lambda).
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values attentions (not used by all models).

@@ -241,6 +250,10 @@ class SpeechT5Config(PretrainedConfig):
         max_speech_positions=4000,
         max_text_positions=450,
         encoder_max_relative_position=160,
+        use_guided_attention_loss=True,
+        guided_attention_loss_num_heads=2,
+        guided_attention_loss_sigma=0.4,
+        guided_attention_loss_scale=10.0,
         use_cache=True,
         is_encoder_decoder=True,
         **kwargs,

@@ -311,6 +324,12 @@ class SpeechT5Config(PretrainedConfig):
         self.max_speech_positions = max_speech_positions
         self.max_text_positions = max_text_positions
         self.encoder_max_relative_position = encoder_max_relative_position
+        self.use_guided_attention_loss = use_guided_attention_loss
+        self.guided_attention_loss_num_heads = guided_attention_loss_num_heads
+        self.guided_attention_loss_sigma = guided_attention_loss_sigma
+        self.guided_attention_loss_scale = guided_attention_loss_scale
         self.use_cache = use_cache
         self.is_encoder_decoder = is_encoder_decoder

View File

@@ -351,7 +351,6 @@ def convert_speecht5_checkpoint(
     if vocab_path:
         tokenizer = SpeechT5Tokenizer(vocab_path, model_max_length=config.max_text_positions)

-        if task == "pretrain":
         # Mask token behaves like a normal word, i.e. include the space before it
         mask_token = AddedToken("<mask>", lstrip=True, rstrip=False)
         tokenizer.mask_token = mask_token

View File

@@ -14,12 +14,13 @@
 # limitations under the License.
 """Feature extractor class for SpeechT5."""

-from typing import List, Optional, Union
+import warnings
+from typing import Any, Dict, List, Optional, Union

 import numpy as np
 import torch
-import torchaudio

+from ...audio_utils import get_mel_filter_banks
 from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
 from ...feature_extraction_utils import BatchFeature
 from ...utils import PaddingStrategy, TensorType, logging

@@ -60,7 +61,7 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
         win_function (`str`, *optional*, defaults to `"hann_window"`):
             Name for the window function used for windowing, must be accessible via `torch.{win_function}`
         frame_signal_scale (`float`, *optional*, defaults to 1.0):
-            Constant multiplied in creating the frames before applying DFT.
+            Constant multiplied in creating the frames before applying DFT. This argument is deprecated.
         fmin (`float`, *optional*, defaults to 80):
             Minimum mel frequency in Hz.
         fmax (`float`, *optional*, defaults to 7600):

@@ -68,7 +69,7 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
         mel_floor (`float`, *optional*, defaults to 1e-10):
             Minimum value of mel frequency banks.
         reduction_factor (`int`, *optional*, defaults to 2):
-            Spectrogram length reduction factor.
+            Spectrogram length reduction factor. This argument is deprecated.
         return_attention_mask (`bool`, *optional*, defaults to `True`):
             Whether or not [`~SpeechT5FeatureExtractor.__call__`] should return `attention_mask`.
     """

@@ -109,10 +110,33 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
         self.sample_size = win_length * sampling_rate // 1000
         self.sample_stride = hop_length * sampling_rate // 1000
         self.n_fft = 2 ** int(np.ceil(np.log2(self.sample_size)))
         self.n_freqs = (self.n_fft // 2) + 1
window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True)
self.window = window.numpy().astype(np.float64)
self.mel_filters = get_mel_filter_banks(
nb_frequency_bins=self.n_freqs,
nb_mel_filters=self.num_mel_bins,
frequency_min=self.fmin,
frequency_max=self.fmax,
sample_rate=self.sampling_rate,
norm="slaney",
mel_scale="slaney",
)
if frame_signal_scale != 1.0:
warnings.warn(
"The argument `frame_signal_scale` is deprecated and will be removed in version 4.30.0 of Transformers",
FutureWarning,
)
if reduction_factor != 2.0:
warnings.warn(
"The argument `reduction_factor` is deprecated and will be removed in version 4.30.0 of Transformers",
FutureWarning,
)
    @staticmethod
    # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
    def zero_mean_unit_var_norm(
@@ -137,99 +161,45 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
         return normed_input_values

     @staticmethod
-    def _center_pad(one_waveform, n_fft, pad_mode):
-        padding = [(int(n_fft // 2), int(n_fft // 2))]
-        return np.pad(one_waveform, padding, mode=pad_mode)
-
-    @staticmethod
-    # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._num_frames_calc
-    def _num_frames_calc(in_size, frame_size, frame_stride):
-        return int(1 + np.floor((in_size - frame_size) * 1 / frame_stride))
-
-    @staticmethod
-    # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._frame_signal
-    def _frame_signal(one_waveform, n_frames, frame_signal_scale, window_length, sample_stride):
-        scale = frame_signal_scale
-        frames = np.zeros(n_frames * window_length)
-        for frame_idx in range(n_frames):
-            start = frame_idx * window_length
-            end = (frame_idx + 1) * window_length
-            wave_start = frame_idx * sample_stride
-            wave_end = frame_idx * sample_stride + window_length
-            frames[start:end] = scale * one_waveform[wave_start:wave_end]
-        return frames
-
-    @staticmethod
-    # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._windowing
-    def _windowing(frames, window_length, window):
-        if frames.size % window_length != 0:
-            raise ValueError(
-                f"`frames` is supposed to have length divisble by `window_length`, but is {frames.size} with"
-                f" window_length={window_length}."
-            )
-        shaped = frames.reshape(-1, window_length)
-        shaped = window * shaped
-        return shaped
-
-    @staticmethod
-    # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._dft
-    def _dft(frames, K, n_frames, n_samples, n_fft):
-        dft = np.zeros([n_frames, K])
-        for frame in range(n_frames):
-            begin = frame * n_samples
-            inwards_buffer = frames[begin : begin + n_samples]
-            inwards_buffer = np.pad(inwards_buffer, (0, n_fft - n_samples), "constant")
-            out = np.fft.rfft(inwards_buffer)
-            dft[frame] = np.abs(out[:K])
-        return dft
+    def _stft(waveform: np.ndarray, fft_length: int, hop_length: int, window: np.ndarray) -> np.ndarray:
+        """
+        Calculates the magnitude spectrogram over one waveform array.
+        """
+        # center pad the waveform
+        padding = [(int(fft_length // 2), int(fft_length // 2))]
+        waveform = np.pad(waveform, padding, mode="reflect")
+        waveform_size = waveform.size
+
+        # promote to float64, since np.fft uses float64 internally
+        waveform = waveform.astype(np.float64)
+
+        num_frames = int(1 + np.floor((waveform_size - fft_length) / hop_length))
+        num_frequency_bins = (fft_length // 2) + 1
+        spectrogram = np.empty((num_frames, num_frequency_bins))
+
+        start = 0
+        for frame_idx in range(num_frames):
+            frame = waveform[start : start + fft_length] * window
+            spectrogram[frame_idx] = np.abs(np.fft.rfft(frame))
+            start += hop_length
+
+        return spectrogram

-    def _extract_fbank_features(
+    def _extract_mel_features(
         self,
         one_waveform: np.ndarray,
     ) -> np.ndarray:
         """
-        Extracts log-mel filterbank features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC
-        code and librosa.
+        Extracts log-mel filterbank features for one waveform array (unbatched).
         """
-        one_waveform = self._center_pad(one_waveform, self.n_fft, "reflect")
-
-        n_frames = self._num_frames_calc(one_waveform.size, self.sample_size, self.sample_stride)
-
-        frames = self._frame_signal(
-            one_waveform, n_frames, self.frame_signal_scale, self.sample_size, self.sample_stride
-        )
-
-        window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True)
-        window = window.numpy()
-
-        frames = self._windowing(frames, self.sample_size, window)
-
-        dft_out = self._dft(frames.flatten(), self.n_freqs, n_frames, self.sample_size, self.n_fft)
-
-        fbanks = torchaudio.functional.melscale_fbanks(
-            n_freqs=self.n_freqs,
-            f_min=self.fmin,
-            f_max=self.fmax,
-            n_mels=self.num_mel_bins,
-            sample_rate=self.sampling_rate,
-            norm="slaney",
-            mel_scale="slaney",
-        )
-        fbanks = fbanks.numpy()
-
-        return np.log10(np.maximum(self.mel_floor, np.dot(dft_out, fbanks)))
-
-    def _reduce(self, inputs):
-        reduced = []
-        for i in range(len(inputs)):
-            reduced.append(inputs[i][self.reduction_factor - 1 :: self.reduction_factor])
-        return reduced
+        if self.n_fft != self.sample_size:
+            raise NotImplementedError(
+                f"Currently the STFT frame size must be a power of two, but got {self.sample_size} for a window length of {self.win_length} and sampling rate of {self.sampling_rate}. Ensure `win_length * sampling_rate // 1000` is divisible by two."
+            )
+
+        stft_out = self._stft(one_waveform, self.n_fft, self.sample_stride, self.window)
+
+        return np.log10(np.maximum(self.mel_floor, np.dot(stft_out, self.mel_filters)))

     def __call__(
         self,
@@ -341,7 +311,6 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
                 return inputs_target
             else:
                 inputs["labels"] = inputs_target["input_values"]
-                inputs["stop_labels"] = inputs_target["stop_labels"]
                 decoder_attention_mask = inputs_target.get("attention_mask")
                 if decoder_attention_mask is not None:
                     inputs["decoder_attention_mask"] = decoder_attention_mask

@@ -381,8 +350,7 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
         # convert into correct format for padding
         if is_target:
-            features = [self._extract_fbank_features(waveform) for waveform in speech]
-            fbank_sizes = [len(x) for x in features]
+            features = [self._extract_mel_features(waveform) for waveform in speech]
             encoded_inputs = BatchFeature({"input_values": features})
             self.feature_size = self.num_mel_bins
         else:

@@ -429,22 +397,18 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
                 padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value
             )

-        if is_target:
-            # make labels for stop prediction
-            stop_labels = []
-            for i, l in enumerate(fbank_sizes):
-                labels = np.zeros(len(padded_inputs["input_values"][i]))
-                labels[l - 1 :] = 1.0
-                stop_labels.append(labels)
-            padded_inputs["stop_labels"] = stop_labels
-
-            # thin out frames for reduction factor
-            if self.reduction_factor > 1:
-                padded_inputs["input_values"] = self._reduce(padded_inputs["input_values"])
-                if attention_mask is not None:
-                    padded_inputs["attention_mask"] = self._reduce(padded_inputs["attention_mask"])
-
         if return_tensors is not None:
             padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

         return padded_inputs
def to_dict(self) -> Dict[str, Any]:
output = super().to_dict()
# Don't serialize these as they are derived from the other properties.
names = ["window", "mel_filters", "sample_size", "sample_stride", "n_fft", "n_freqs"]
for name in names:
if name in output:
del output[name]
return output
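The net effect of the refactor above: the window and mel filter banks are built once in `__init__`, and feature extraction reduces to a plain magnitude STFT followed by a log-mel projection. A rough NumPy sketch of that pipeline, with illustrative sizes rather than the extractor's real defaults:

    import numpy as np

    def magnitude_stft(waveform, fft_length, hop_length, window):
        # center-pad, then take |rfft| of each windowed frame
        pad = fft_length // 2
        waveform = np.pad(waveform, [(pad, pad)], mode="reflect").astype(np.float64)
        num_frames = 1 + (len(waveform) - fft_length) // hop_length
        frames = np.stack(
            [waveform[i * hop_length : i * hop_length + fft_length] * window for i in range(num_frames)]
        )
        return np.abs(np.fft.rfft(frames, axis=-1))

    # toy input: 1 second of noise at 16 kHz, 1024-sample window, 256-sample hop
    audio = np.random.default_rng(0).standard_normal(16000)
    spec = magnitude_stft(audio, fft_length=1024, hop_length=256, window=np.hanning(1024))
    print(spec.shape)  # (frames, 513); log-mel would follow as np.log10(np.maximum(1e-10, spec @ mel_filters))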

View File

@@ -16,13 +16,14 @@

 import math
 import random
+import warnings
 from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss

 from ...activations import ACT2FN
 from ...deepspeed import is_deepspeed_zero3_enabled

@@ -72,12 +73,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
     return shifted_input_ids


-def shift_spectrograms_right(input_values: torch.Tensor):
+def shift_spectrograms_right(input_values: torch.Tensor, reduction_factor: int = 1):
     """
-    Shift input spectrograms one timestep to the right.
+    Shift input spectrograms one timestep to the right. Also applies the reduction factor to the sequence length.
     """
+    # thin out frames for reduction factor
+    if reduction_factor > 1:
+        input_values = input_values[:, reduction_factor - 1 :: reduction_factor]
+
     shifted_input_values = input_values.new_zeros(input_values.shape)
     shifted_input_values[:, 1:] = input_values[:, :-1].clone()

+    # replace possible -100 values in labels by zeros
+    shifted_input_values.masked_fill_(shifted_input_values == -100.0, 0.0)
+
     return shifted_input_values
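A toy illustration of what the updated `shift_spectrograms_right` does during training: thin the target frames by the reduction factor, shift them one step right so the decoder predicts frame t from earlier frames, and zero out the `-100.0` padding markers:

    import torch

    labels = torch.arange(24, dtype=torch.float).reshape(1, 6, 4)  # (batch, frames, mel bins)
    labels[0, 4:] = -100.0                                         # last two frames are padding

    reduction_factor = 2
    thinned = labels[:, reduction_factor - 1 :: reduction_factor]  # keep every 2nd frame -> 3 frames
    shifted = thinned.new_zeros(thinned.shape)
    shifted[:, 1:] = thinned[:, :-1].clone()
    shifted.masked_fill_(shifted == -100.0, 0.0)                   # padded frames become zeros
    print(shifted.shape)  # torch.Size([1, 3, 4])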
@ -565,7 +574,7 @@ class SpeechT5SpeechEncoderPrenet(nn.Module):
def forward( def forward(
self, self,
input_values: torch.Tensor, input_values: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.LongTensor] = None,
mask_time_indices: Optional[torch.FloatTensor] = None, mask_time_indices: Optional[torch.FloatTensor] = None,
): ):
extract_features = self.feature_encoder(input_values) extract_features = self.feature_encoder(input_values)
@ -840,7 +849,7 @@ class SpeechT5TextDecoderPrenet(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[List[torch.FloatTensor]] = None,
): ):
if input_ids is not None: if input_ids is not None:
@ -1574,7 +1583,7 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel):
def forward( def forward(
self, self,
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None,
@ -1589,7 +1598,7 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel):
Args: Args:
hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`): hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`):
Features extracted from the speech or text input by the decoder prenet. Features extracted from the speech or text input by the decoder prenet.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**, - 1 for tokens that are **not masked**,
@ -1783,7 +1792,7 @@ class SpeechT5DecoderWithSpeechPrenet(SpeechT5PreTrainedModel):
def forward( def forward(
self, self,
input_values: Optional[torch.FloatTensor] = None, input_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeddings: Optional[torch.Tensor] = None, speaker_embeddings: Optional[torch.Tensor] = None,
@ -1837,7 +1846,7 @@ class SpeechT5DecoderWithTextPrenet(SpeechT5PreTrainedModel):
def forward( def forward(
self, self,
input_values: Optional[torch.FloatTensor] = None, input_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None,
@ -1884,7 +1893,7 @@ class SpeechT5DecoderWithoutPrenet(SpeechT5PreTrainedModel):
def forward( def forward(
self, self,
input_values: Optional[torch.FloatTensor] = None, input_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None,
@ -1911,6 +1920,126 @@ class SpeechT5DecoderWithoutPrenet(SpeechT5PreTrainedModel):
        return outputs
class SpeechT5GuidedMultiheadAttentionLoss(nn.Module):
"""
Guided attention loss from the paper [Efficiently Trainable Text-to-Speech System Based on Deep Convolutional
Networks with Guided Attention](https://arxiv.org/abs/1710.08969), adapted for multi-head attention.
"""
def __init__(self, config: SpeechT5Config):
super().__init__()
self.sigma = config.guided_attention_loss_sigma
self.scale = config.guided_attention_loss_scale
def forward(
self, attentions: torch.FloatTensor, input_masks: torch.BoolTensor, output_masks: torch.BoolTensor
) -> torch.Tensor:
"""
Compute the attention loss.
Args:
attentions (`torch.FloatTensor` of shape `(batch_size, layers * heads, output_sequence_length, input_sequence_length)`):
Batch of multi-head attention weights
input_masks (`torch.BoolTensor` of shape `(batch_size, input_sequence_length)`):
Input attention mask as booleans.
output_masks (`torch.BoolTensor` of shape `(batch_size, output_sequence_length)`):
Target attention mask as booleans.
Returns:
`torch.Tensor` with the loss value
"""
guided_attn_masks = self._make_guided_attention_masks(input_masks, output_masks, attentions.device)
masks = output_masks.unsqueeze(-1) & input_masks.unsqueeze(-2)
masks = masks.to(attentions.device).unsqueeze(1)
losses = guided_attn_masks * attentions
loss = torch.mean(losses.masked_select(masks))
return self.scale * loss
def _make_guided_attention_masks(self, input_masks, output_masks, device):
input_lengths = input_masks.sum(-1)
output_lengths = output_masks.sum(-1)
guided_attn_masks = torch.zeros((len(input_masks), output_masks.shape[1], input_masks.shape[1]), device=device)
for idx, (ilen, olen) in enumerate(zip(input_lengths, output_lengths)):
guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(ilen, olen, self.sigma, device)
return guided_attn_masks.unsqueeze(1)
@staticmethod
def _make_guided_attention_mask(input_length, output_length, sigma, device):
grid_y, grid_x = torch.meshgrid(
torch.arange(input_length, device=device),
torch.arange(output_length, device=device),
indexing="xy",
)
grid_x = grid_x.float() / output_length
grid_y = grid_y.float() / input_length
return 1.0 - torch.exp(-((grid_y - grid_x) ** 2) / (2 * (sigma**2)))
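Concretely, the mask built by `_make_guided_attention_mask` penalizes attention weights that stray from the diagonal where n/N is close to t/T, following the DC-TTS paper cited above. A small self-contained sketch of just that mask (the sigma value is the config default):

    import torch

    def guided_attention_mask(input_length, output_length, sigma=0.4):
        grid_y, grid_x = torch.meshgrid(
            torch.arange(input_length), torch.arange(output_length), indexing="xy"
        )
        grid_x = grid_x.float() / output_length
        grid_y = grid_y.float() / input_length
        # near zero on the diagonal, approaching 1 far away from it
        return 1.0 - torch.exp(-((grid_y - grid_x) ** 2) / (2 * sigma**2))

    print(guided_attention_mask(5, 8).shape)  # (output_length, input_length) = (8, 5)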
class SpeechT5SpectrogramLoss(nn.Module):
"""
Loss computation used by SpeechT5ForTextToSpeech.
"""
def __init__(self, config: SpeechT5Config):
super().__init__()
self.use_guided_attention_loss = config.use_guided_attention_loss
self.guided_attention_loss_num_heads = config.guided_attention_loss_num_heads
self.reduction_factor = config.reduction_factor
self.l1_criterion = L1Loss()
self.bce_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(5.0))
if self.use_guided_attention_loss:
self.attn_criterion = SpeechT5GuidedMultiheadAttentionLoss(config)
def forward(
self,
attention_mask: torch.LongTensor,
outputs_before_postnet: torch.FloatTensor,
outputs_after_postnet: torch.FloatTensor,
logits: torch.FloatTensor,
labels: torch.FloatTensor,
cross_attentions: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
padding_mask = labels != -100.0
# mask out the padded portions
labels = labels.masked_select(padding_mask)
outputs_before_postnet = outputs_before_postnet.masked_select(padding_mask)
outputs_after_postnet = outputs_after_postnet.masked_select(padding_mask)
# spectrogram loss
l1_loss = self.l1_criterion(outputs_after_postnet, labels) + self.l1_criterion(outputs_before_postnet, labels)
# construct stop labels from the padding mask
masks = padding_mask[:, :, 0]
stop_labels = torch.cat([~masks * 1.0, torch.ones(masks.size(0), 1).to(masks.device)], dim=1)
stop_labels = stop_labels[:, 1:].masked_select(masks)
logits = logits.masked_select(masks)
# stop token loss
bce_loss = self.bce_criterion(logits, stop_labels)
# combined loss
loss = l1_loss + bce_loss
# guided attention loss
if self.use_guided_attention_loss:
attn = torch.cat([x[:, : self.guided_attention_loss_num_heads] for x in cross_attentions], dim=1)
input_masks = attention_mask == 1
output_masks = padding_mask[:, :, 0]
if self.reduction_factor > 1:
output_masks = output_masks[:, self.reduction_factor - 1 :: self.reduction_factor]
attn_loss = self.attn_criterion(attn, input_masks, output_masks)
loss += attn_loss
return loss
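In shorthand, the total objective computed by this class, with the guided-attention term present only when `use_guided_attention_loss` is enabled and all spectrogram terms evaluated only on frames whose label is not `-100.0`, is:

    loss = L1(before_postnet, labels) + L1(after_postnet, labels)
         + BCEWithLogits(stop_logits, stop_labels, pos_weight=5.0)
         + guided_attention_loss_scale * guided_attention_term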
SPEECHT5_BASE_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@ -1981,13 +2110,13 @@ SPEECHT5_INPUTS_DOCSTRING = r"""
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy. information on the default strategy.
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): head_mask (`torch.FloatTensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): decoder_head_mask (`torch.FloatTensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
@ -2088,27 +2217,27 @@ class SpeechT5Model(SpeechT5PreTrainedModel):
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_values: Optional[torch.FloatTensor] = None, input_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.LongTensor] = None,
decoder_input_values: Optional[torch.Tensor] = None, decoder_input_values: Optional[torch.Tensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None, decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
speaker_embeddings: Optional[torch.Tensor] = None, speaker_embeddings: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
r""" r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`):
Depending on which encoder is being used, the `input_values` are either: float values of the input raw Depending on which encoder is being used, the `input_values` are either: float values of the input raw
speech waveform, or indices of input sequence tokens in the vocabulary, or hidden states. speech waveform, or indices of input sequence tokens in the vocabulary, or hidden states.
decoder_input_values (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): decoder_input_values (`torch.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Depending on which decoder is being used, the `decoder_input_values` are either: float values of log-mel Depending on which decoder is being used, the `decoder_input_values` are either: float values of log-mel
filterbank features extracted from the raw speech waveform, or indices of decoder input sequence tokens in filterbank features extracted from the raw speech waveform, or indices of decoder input sequence tokens in
the vocabulary, or hidden states. the vocabulary, or hidden states.
@ -2246,10 +2375,10 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_values: Optional[torch.Tensor] = None, input_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.Tensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None, decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None,
@ -2259,7 +2388,7 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
labels: Optional[torch.Tensor] = None, labels: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Seq2SeqLMOutput]: ) -> Union[Tuple, Seq2SeqLMOutput]:
r""" r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -2414,7 +2543,8 @@ def _generate_speech(
     minlenratio: float = 0.0,
     maxlenratio: float = 20.0,
     vocoder: Optional[nn.Module] = None,
-) -> torch.FloatTensor:
+    output_cross_attentions: bool = False,
+) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
     encoder_attention_mask = torch.ones_like(input_values)

     encoder_out = model.speecht5.encoder(

@@ -2438,6 +2568,7 @@ def _generate_speech(
     output_sequence = encoder_last_hidden_state.new_zeros(1, 1, model.config.num_mel_bins)

     spectrogram = []
+    cross_attentions = []
     past_key_values = None
     idx = 0

@@ -2455,9 +2586,13 @@ def _generate_speech(
             encoder_attention_mask=encoder_attention_mask,
             past_key_values=past_key_values,
             use_cache=True,
+            output_attentions=output_cross_attentions,
             return_dict=True,
         )

+        if output_cross_attentions:
+            cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0))
+
         last_decoder_output = decoder_out.last_hidden_state[0, -1]
         past_key_values = decoder_out.past_key_values

@@ -2480,9 +2615,15 @@ def _generate_speech(
             break

     if vocoder is not None:
-        return vocoder(spectrogram)
+        outputs = vocoder(spectrogram)
     else:
-        return spectrogram
+        outputs = spectrogram
+
+    if output_cross_attentions:
+        cross_attentions = torch.cat(cross_attentions, dim=2)
+        outputs = (outputs, cross_attentions)
+
+    return outputs
 @add_start_docstrings(

@@ -2525,10 +2666,10 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
     @replace_return_docstrings(output_type=Seq2SeqSpectrogramOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        decoder_input_values: Optional[torch.Tensor] = None,
-        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        decoder_input_values: Optional[torch.FloatTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         decoder_head_mask: Optional[torch.FloatTensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,

@@ -2538,8 +2679,8 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        speaker_embeddings: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
+        speaker_embeddings: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.FloatTensor] = None,
         stop_labels: Optional[torch.Tensor] = None,
     ) -> Union[Tuple, Seq2SeqSpectrogramOutput]:
         r"""

@@ -2559,13 +2700,9 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
         speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
             Tensor containing the speaker embeddings.
         labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*):
-            Float values of target mel spectrogram. Spectrograms can be obtained using [`SpeechT5Processor`]. See
-            [`SpeechT5Processor.__call__`] for details.
-        stop_labels (`torch.FloatTensor` of shape `(batch_size, unreduced_sequence_length)`, *optional*):
-            Labels for computing the stop token loss. Values are 0.0 until the end of the sequence, after which they
-            become 1.0. The sequence length of this tensor is `config.reduction_factor` times larger than the length of
-            the target mel spectrogram. Labels can be obtained using [`SpeechT5Processor`]. See
-            [`SpeechT5Processor.__call__`] for details.
+            Float values of target mel spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss
+            computation. Spectrograms can be obtained using [`SpeechT5Processor`]. See [`SpeechT5Processor.__call__`]
+            for details.

         Returns:

@@ -2592,9 +2729,17 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        if stop_labels is not None:
+            warnings.warn(
+                "The argument `stop_labels` is deprecated and will be removed in version 4.30.0 of Transformers",
+                FutureWarning,
+            )
+
         if labels is not None:
             if decoder_input_values is None:
-                decoder_input_values = shift_spectrograms_right(labels)
+                decoder_input_values = shift_spectrograms_right(labels, self.config.reduction_factor)
+            if self.config.use_guided_attention_loss:
+                output_attentions = True

         outputs = self.speecht5(
             input_values=input_ids,

@@ -2613,17 +2758,27 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
             return_dict=True,
         )

-        _, spectrogram, logits = self.speech_decoder_postnet(outputs[0])
+        outputs_before_postnet, outputs_after_postnet, logits = self.speech_decoder_postnet(outputs[0])

         loss = None
+        if labels is not None:
+            criterion = SpeechT5SpectrogramLoss(self.config)
+            loss = criterion(
+                attention_mask,
+                outputs_before_postnet,
+                outputs_after_postnet,
+                logits,
+                labels,
+                outputs.cross_attentions,
+            )

         if not return_dict:
-            output = (spectrogram,) + outputs[1:]
+            output = (outputs_after_postnet,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output

         return Seq2SeqSpectrogramOutput(
             loss=loss,
-            spectrogram=spectrogram,
+            spectrogram=outputs_after_postnet,
             past_key_values=outputs.past_key_values,
             decoder_hidden_states=outputs.decoder_hidden_states,
             decoder_attentions=outputs.decoder_attentions,
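Putting the training path together, a hedged end-to-end sketch of a single fine-tuning step; the checkpoint name, the dummy target waveform, and the all-zeros speaker embedding are placeholder assumptions, not part of this commit:

    import numpy as np
    import torch
    from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech

    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")

    waveform = np.random.randn(16000).astype(np.float32)  # stand-in for a real 16 kHz target utterance
    batch = processor(
        text="hello world", audio_target=waveform, sampling_rate=16000, return_tensors="pt"
    )

    outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["labels"],
        speaker_embeddings=torch.zeros((1, 512)),  # placeholder speaker embedding
    )
    outputs.loss.backward()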
@@ -2642,7 +2797,8 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
         minlenratio: float = 0.0,
         maxlenratio: float = 20.0,
         vocoder: Optional[nn.Module] = None,
-    ) -> torch.FloatTensor:
+        output_cross_attentions: bool = False,
+    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
         r"""
         Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a
         speech waveform using a vocoder.

@@ -2666,10 +2822,18 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
             vocoder (`nn.Module`, *optional*, defaults to `None`):
                 The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel
                 spectrogram.
+            output_cross_attentions (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the attentions tensors of the decoder's cross-attention layers.

         Returns:
-            `torch.FloatTensor`: Tensor of shape `(output_sequence_length, config.num_mel_bins)` containing the
-            predicted mel spectrogram, or a tensor with shape `(num_frames,)` containing the speech waveform.
+            `tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
+            - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
+              `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
+            - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
+              `(num_frames,)` -- The predicted speech waveform.
+            - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) `torch.FloatTensor`
+              of shape `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length,
+              input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
         """
         return _generate_speech(
             self,

@@ -2679,6 +2843,7 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
             minlenratio,
             maxlenratio,
             vocoder,
+            output_cross_attentions,
         )
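A hedged usage sketch of the new `output_cross_attentions` flag together with a vocoder; the checkpoint names are the commonly paired SpeechT5 ones and are assumed here, not taken from the diff:

    import torch
    from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

    inputs = processor(text="hello world", return_tensors="pt")
    waveform, cross_attentions = model.generate_speech(
        inputs["input_ids"],
        speaker_embeddings=torch.zeros((1, 512)),  # placeholder speaker embedding
        vocoder=vocoder,
        output_cross_attentions=True,
    )
    # cross_attentions: (decoder_layers, decoder_attention_heads, output_sequence_length, input_sequence_length)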
@ -2723,10 +2888,10 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
@replace_return_docstrings(output_type=Seq2SeqSpectrogramOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=Seq2SeqSpectrogramOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_values: Optional[torch.Tensor] = None, input_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.LongTensor] = None,
decoder_input_values: Optional[torch.Tensor] = None, decoder_input_values: Optional[torch.FloatTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None, decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None,
@ -2736,8 +2901,8 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
speaker_embeddings: Optional[torch.Tensor] = None, speaker_embeddings: Optional[torch.FloatTensor] = None,
labels: Optional[torch.Tensor] = None, labels: Optional[torch.FloatTensor] = None,
stop_labels: Optional[torch.Tensor] = None, stop_labels: Optional[torch.Tensor] = None,
) -> Union[Tuple, Seq2SeqSpectrogramOutput]: ) -> Union[Tuple, Seq2SeqSpectrogramOutput]:
r""" r"""
@ -2757,11 +2922,6 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*): labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*):
Float values of target mel spectrogram. Spectrograms can be obtained using [`SpeechT5Processor`]. See Float values of target mel spectrogram. Spectrograms can be obtained using [`SpeechT5Processor`]. See
[`SpeechT5Processor.__call__`] for details. [`SpeechT5Processor.__call__`] for details.
stop_labels (`torch.FloatTensor` of shape `(batch_size, unreduced_sequence_length)`, *optional*):
Labels for computing the stop token loss. Values are 0.0 until the end of the sequence, after which they
become 1.0. The sequence length of this tensor is `config.reduction_factor` times larger than the length of
the target mel spectrogram. Labels can be obtained using [`SpeechT5Processor`]. See
[`SpeechT5Processor.__call__`] for details.
Returns: Returns:
@ -2797,9 +2957,15 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if stop_labels is not None:
warnings.warn(
"The argument `stop_labels` is deprecated and will be removed in version 4.30.0 of Transformers",
FutureWarning,
)
if labels is not None: if labels is not None:
if decoder_input_values is None: if decoder_input_values is None:
decoder_input_values = shift_spectrograms_right(labels) decoder_input_values = shift_spectrograms_right(labels, self.config.reduction_factor)
outputs = self.speecht5( outputs = self.speecht5(
input_values=input_values, input_values=input_values,
@ -2847,6 +3013,7 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
minlenratio: float = 0.0, minlenratio: float = 0.0,
maxlenratio: float = 20.0, maxlenratio: float = 20.0,
vocoder: Optional[nn.Module] = None, vocoder: Optional[nn.Module] = None,
output_cross_attentions: bool = False,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
r""" r"""
Converts a raw speech waveform into a sequence of mel spectrograms, which are subsequently turned back into a Converts a raw speech waveform into a sequence of mel spectrograms, which are subsequently turned back into a
@ -2871,10 +3038,18 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
vocoder (`nn.Module`, *optional*, defaults to `None`): vocoder (`nn.Module`, *optional*, defaults to `None`):
The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel
spectrogram. spectrogram.
output_cross_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of the decoder's cross-attention layers.
Returns: Returns:
`torch.FloatTensor`: Tensor of shape `(output_sequence_length, config.num_mel_bins)` containing the `tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
predicted mel spectrogram, or a tensor with shape `(num_frames,)` containing the speech waveform. - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
`(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
- **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
`(num_frames,)` -- The predicted speech waveform.
- **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) `torch.FloatTensor`
of shape `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length,
input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
""" """
if speaker_embeddings is None: if speaker_embeddings is None:
speaker_embeddings = torch.zeros((1, 512), device=input_values.device) speaker_embeddings = torch.zeros((1, 512), device=input_values.device)
@ -2887,6 +3062,7 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
minlenratio, minlenratio,
maxlenratio, maxlenratio,
vocoder, vocoder,
output_cross_attentions,
) )
@ -3034,7 +3210,7 @@ class SpeechT5HifiGan(PreTrainedModel):
layer.remove_weight_norm() layer.remove_weight_norm()
nn.utils.remove_weight_norm(self.conv_post) nn.utils.remove_weight_norm(self.conv_post)
def forward(self, spectrogram): def forward(self, spectrogram: torch.FloatTensor) -> torch.FloatTensor:
r""" r"""
Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch
of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech

View File

@@ -87,25 +87,21 @@ class SpeechT5Processor(ProcessorMixin):
         inputs = None
         if audio_target is not None:
-            audio_target_features = self.feature_extractor(
-                audio_target=audio_target, *args, sampling_rate=sampling_rate, **kwargs
-            )
-            if inputs is None:
-                return audio_target_features
-            else:
-                inputs["labels"] = audio_target_features["input_values"]
-                inputs["stop_labels"] = audio_target_features["stop_labels"]
-                decoder_attention_mask = audio_target_features.get("attention_mask")
-                if decoder_attention_mask is not None:
-                    inputs["decoder_attention_mask"] = decoder_attention_mask
-
-        if text_target is not None:
-            encodings_target = self.tokenizer(text_target, **kwargs)
-            if inputs is None:
-                return encodings_target
-            else:
-                inputs["labels"] = encodings_target["input_ids"]
-                decoder_attention_mask = encodings_target.get("attention_mask")
-                if decoder_attention_mask is not None:
-                    inputs["decoder_attention_mask"] = decoder_attention_mask
+            targets = self.feature_extractor(audio_target=audio_target, *args, sampling_rate=sampling_rate, **kwargs)
+            labels = targets["input_values"]
+        elif text_target is not None:
+            targets = self.tokenizer(text_target, **kwargs)
+            labels = targets["input_ids"]
+        else:
+            targets = None
+
+        if inputs is None:
+            return targets
+
+        if targets is not None:
+            inputs["labels"] = labels
+            decoder_attention_mask = targets.get("attention_mask")
+            if decoder_attention_mask is not None:
+                inputs["decoder_attention_mask"] = decoder_attention_mask
@@ -113,33 +109,63 @@ class SpeechT5Processor(ProcessorMixin):
     def pad(self, *args, **kwargs):
         """
-        This method forwards all its arguments to SpeechT5FeatureExtractor's [`~SpeechT5FeatureExtractor.pad`] and
-        returns its output.
-
-        You can process your labels by using the argument `text` (either in the same call as your audio inputs, or in a
-        separate call). This forwards its arguments to SpeechT5Tokenizer's [`~SpeechT5Tokenizer.pad`].
+        Collates the audio and text inputs, as well as their targets, into a padded batch.
+
+        Audio inputs are padded by SpeechT5FeatureExtractor's [`~SpeechT5FeatureExtractor.pad`]. Text inputs are padded
+        by SpeechT5Tokenizer's [`~SpeechT5Tokenizer.pad`].
+
+        Valid input combinations are:
+
+        - `input_ids` only
+        - `input_values` only
+        - `labels` only, either log-mel spectrograms or text tokens
+        - `input_ids` and log-mel spectrogram `labels`
+        - `input_values` and text `labels`

         Please refer to the docstring of the above two methods for more information.
         """
-        input_features = kwargs.pop("input_features", None)
+        input_values = kwargs.pop("input_values", None)
+        input_ids = kwargs.pop("input_ids", None)
         labels = kwargs.pop("labels", None)
-        if len(args) > 0:
-            input_features = args[0]
-            args = args[1:]
-
-        if input_features is not None:
-            input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
-        if labels is not None:
-            labels = self.tokenizer.pad(labels, **kwargs)
-
-        if labels is None:
-            return input_features
-        elif input_features is None:
-            return labels
-        else:
-            input_features["labels"] = labels["input_ids"]
-            return input_features
+
+        if input_values is not None and input_ids is not None:
+            raise ValueError("Cannot process both `input_values` and `input_ids` inputs.")
+        if input_values is None and input_ids is None and labels is None:
+            raise ValueError(
+                "You need to specify either an `input_values`, `input_ids`, or `labels` input to be padded."
+            )
+
+        if input_values is not None:
+            inputs = self.feature_extractor.pad(input_values, *args, **kwargs)
+        elif input_ids is not None:
+            inputs = self.tokenizer.pad(input_ids, **kwargs)
+        else:
+            inputs = None
+
+        if labels is not None:
+            if "input_ids" in labels or (isinstance(labels, list) and "input_ids" in labels[0]):
+                targets = self.tokenizer.pad(labels, **kwargs)
+                labels = targets["input_ids"]
+            else:
+                feature_size_hack = self.feature_extractor.feature_size
+                self.feature_extractor.feature_size = self.feature_extractor.num_mel_bins
+                targets = self.feature_extractor.pad(labels, *args, **kwargs)
+                self.feature_extractor.feature_size = feature_size_hack
+                labels = targets["input_values"]
+        else:
+            targets = None
+
+        if inputs is None:
+            return targets
+
+        if targets is not None:
+            inputs["labels"] = labels
+            decoder_attention_mask = targets.get("attention_mask")
+            if decoder_attention_mask is not None:
+                inputs["decoder_attention_mask"] = decoder_attention_mask
+
+        return inputs
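For fine-tuning, the updated `pad` can serve directly as a collator over tokenized text plus log-mel targets. A hedged sketch of that usage; the toy spectrogram shapes and the checkpoint name are assumptions:

    import numpy as np
    from transformers import SpeechT5Processor

    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")

    texts = processor(text=["hello world", "goodbye"])  # unpadded token ids
    specs = [np.zeros((12, 80), dtype=np.float32), np.zeros((7, 80), dtype=np.float32)]

    batch = processor.pad(
        input_ids=[{"input_ids": ids} for ids in texts["input_ids"]],
        labels=[{"input_values": spec} for spec in specs],
        return_tensors="pt",
    )
    print(batch["input_ids"].shape, batch["labels"].shape, batch["decoder_attention_mask"].shape)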
    def batch_decode(self, *args, **kwargs):
        """

View File

@@ -167,6 +167,26 @@ class SpeechT5Tokenizer(PreTrainedTokenizer):
         out_string += self.sp_model.decode(current_sub_tokens)
         return out_string.strip()
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
"""Build model inputs from a sequence by appending eos_token_id."""
if token_ids_1 is None:
return token_ids_0 + [self.eos_token_id]
# We don't expect to process pairs, but leave the pair logic for API consistency
return token_ids_0 + token_ids_1 + [self.eos_token_id]
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
suffix_ones = [1]
if token_ids_1 is None:
return ([0] * len(token_ids_0)) + suffix_ones
return ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
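A small hedged check of the new behavior; the checkpoint name is assumed, and token values depend on the vocabulary, so only the structure is asserted:

    from transformers import SpeechT5Tokenizer

    tokenizer = SpeechT5Tokenizer.from_pretrained("microsoft/speecht5_tts")

    ids = tokenizer("hello")["input_ids"]
    assert ids[-1] == tokenizer.eos_token_id           # eos_token_id is now appended
    mask = tokenizer.get_special_tokens_mask(ids[:-1])
    assert mask == [0] * (len(ids) - 1) + [1]          # only the suffix eos is marked special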
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")

View File

@@ -21,7 +21,7 @@ import unittest
 import numpy as np

 from transformers import BatchFeature, is_speech_available
-from transformers.testing_utils import require_torch, require_torchaudio
+from transformers.testing_utils import require_torch
 from transformers.utils.import_utils import is_torch_available

 from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin

@@ -67,11 +67,9 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
         hop_length=16,
         win_length=64,
         win_function="hann_window",
-        frame_signal_scale=1.0,
         fmin=80,
         fmax=7600,
         mel_floor=1e-10,
-        reduction_factor=2,
         return_attention_mask=True,
     ):
         self.parent = parent

@@ -87,11 +85,9 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
         self.hop_length = hop_length
         self.win_length = win_length
         self.win_function = win_function
-        self.frame_signal_scale = frame_signal_scale
         self.fmin = fmin
         self.fmax = fmax
         self.mel_floor = mel_floor
-        self.reduction_factor = reduction_factor
         self.return_attention_mask = return_attention_mask

     def prepare_feat_extract_dict(self):

@@ -104,11 +100,9 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
             "hop_length": self.hop_length,
             "win_length": self.win_length,
             "win_function": self.win_function,
-            "frame_signal_scale": self.frame_signal_scale,
             "fmin": self.fmin,
             "fmax": self.fmax,
             "mel_floor": self.mel_floor,
-            "reduction_factor": self.reduction_factor,
             "return_attention_mask": self.return_attention_mask,
         }

@@ -147,7 +141,6 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
 @require_torch
-@require_torchaudio
 class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
     feature_extraction_class = SpeechT5FeatureExtractor if is_speech_available() else None

@@ -407,10 +400,10 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
     def test_integration_target(self):
         # fmt: off
         EXPECTED_INPUT_VALUES = torch.tensor(
-            [-2.7713, -2.8896, -3.2619, -3.0843, -2.9919, -3.0084, -3.2796, -3.3169,
-             -3.2397, -3.2053, -2.9151, -2.7921, -2.9403, -2.7411, -3.0654, -2.8314,
-             -3.0026, -2.9797, -3.1314, -2.9939, -2.6748, -2.7725, -2.8563, -2.9462,
-             -3.2623, -3.3044, -3.1318, -3.2672, -3.4030, -3.1988]
+            [-2.6870, -3.0104, -3.1356, -3.5352, -3.0044, -3.0353, -3.4719, -3.6777,
+             -3.1520, -2.9435, -2.6553, -2.8795, -2.9944, -2.5921, -3.0279, -3.0386,
+             -3.0864, -3.1291, -3.2353, -2.7444, -2.6831, -2.7287, -3.1761, -3.1571,
+             -3.2726, -3.0582, -3.1007, -3.4533, -3.4695, -3.0998]
         )
         # fmt: on

View File

@@ -25,10 +25,10 @@ from transformers.testing_utils import (
     require_sentencepiece,
     require_tokenizers,
     require_torch,
-    require_torchaudio,
     slow,
     torch_device,
 )
+from transformers.trainer_utils import set_seed
 from transformers.utils import cached_property

 from ...test_configuration_common import ConfigTester

@@ -716,7 +716,6 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
 @require_torch
-@require_torchaudio
 @require_sentencepiece
 @require_tokenizers
 @slow

@@ -991,7 +990,6 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
 @require_torch
-@require_torchaudio
 @require_sentencepiece
 @require_tokenizers
 @slow

@@ -1005,11 +1003,13 @@
         model.to(torch_device)
         processor = self.default_processor

+        set_seed(555)  # make deterministic
+
         input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
         input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)

         generated_speech = model.generate_speech(input_ids)
-        self.assertEqual(generated_speech.shape, (1800, model.config.num_mel_bins))
+        self.assertEqual(generated_speech.shape, (1820, model.config.num_mel_bins))

 @require_torch

@@ -1406,7 +1406,6 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase):
 @require_torch
-@require_torchaudio
 @require_sentencepiece
 @require_tokenizers
 @slow

View File

@ -21,7 +21,7 @@ import unittest
from transformers import is_speech_available, is_torch_available from transformers import is_speech_available, is_torch_available
from transformers.models.speecht5 import SpeechT5Tokenizer from transformers.models.speecht5 import SpeechT5Tokenizer
from transformers.testing_utils import get_tests_dir, require_torch, require_torchaudio from transformers.testing_utils import get_tests_dir, require_torch
from transformers.utils import FEATURE_EXTRACTOR_NAME from transformers.utils import FEATURE_EXTRACTOR_NAME
@ -35,7 +35,6 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_bpe_char.model")
@require_torch @require_torch
@require_torchaudio
class SpeechT5ProcessorTest(unittest.TestCase): class SpeechT5ProcessorTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.tmpdirname = tempfile.mkdtemp() self.tmpdirname = tempfile.mkdtemp()
@ -52,7 +51,6 @@ class SpeechT5ProcessorTest(unittest.TestCase):
"hop_length": 16, "hop_length": 16,
"win_length": 64, "win_length": 64,
"win_function": "hann_window", "win_function": "hann_window",
"frame_signal_scale": 1.0,
"fmin": 80, "fmin": 80,
"fmax": 7600, "fmax": 7600,
"mel_floor": 1e-10, "mel_floor": 1e-10,

File diff suppressed because one or more lines are too long