Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-04 21:30:07 +06:00
TTS fine-tuning for SpeechT5 (#21824)
* wrong argument name
* append eos_token_id
* all tokenizers need mask and ctc_blank tokens
* remove reduction factor from feature extractor
* add proper TTS loss
* did shifting the wrong way around
* mask out padded portions
* remove logits again (don't really need it)
* fix unit tests
* fixup
* pad also returns the decoder attention mask, since that's useful to have
* clean up feature extractor logic
* pad can handle TTS task too
* remove stop_labels from loss calculation
* simplify logic
* fixup
* do -100 masking properly
* small STFT optimization (calculate mel filterbanks only once)
* replace torchaudio fbanks with audio_utils
* remove torchaudio dependency
* simplify & speed up the STFT
* don't serialize window and mel filters
* output cross attentions when generating speech
* add guided attention loss
* fix failing test
* Update src/transformers/models/speecht5/feature_extraction_speecht5.py
  Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Update src/transformers/models/speecht5/modeling_speecht5.py
  Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* change type annotation of attention_mask to LongTensor
* extract loss into class
* remove unused frame_signal_scale argument
* use config object in loss class
* fix type annotations in doc comments
* change optional to just bool
* implement missing tokenizer method
* add deprecation warning
* Update src/transformers/models/speecht5/feature_extraction_speecht5.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/models/speecht5/feature_extraction_speecht5.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* add deprecation warning for stop_labels

---------
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in: parent dacd34568d · commit ac2bc50a10
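Taken together, these changes let SpeechT5ForTextToSpeech be fine-tuned directly from padded log-mel labels: the processor now builds the TTS targets, and the model computes the combined spectrogram/stop/guided-attention loss itself. As a rough sketch of the intended workflow (not part of the diff; the checkpoint name, dummy waveform, example text, and random speaker embedding are placeholders):

import numpy as np
import torch
from transformers import SpeechT5ForTextToSpeech, SpeechT5Processor

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")

# stand-in for a real 16 kHz target recording
target_waveform = np.random.randn(16000).astype(np.float32)

# text goes in, the target waveform becomes the log-mel `labels`
inputs = processor(
    text="welcome to speecht5 fine-tuning",
    audio_target=target_waveform,
    sampling_rate=16000,
    return_tensors="pt",
)

speaker_embeddings = torch.randn(1, 512)  # placeholder x-vector

outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=inputs["labels"],
    decoder_attention_mask=inputs["decoder_attention_mask"],
    speaker_embeddings=speaker_embeddings,
)
outputs.loss.backward()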
@ -161,13 +161,22 @@ class SpeechT5Config(PretrainedConfig):
speech_decoder_postnet_dropout (`float`, *optional*, defaults to 0.5):
The dropout probability for the speech decoder post-net layers.
reduction_factor (`int`, *optional*, defaults to 2):
Spectrogram length reduction factor for the speech decoder post-net.
Spectrogram length reduction factor for the speech decoder inputs.
max_speech_positions (`int`, *optional*, defaults to 4000):
The maximum sequence length of speech features that this model might ever be used with.
max_text_positions (`int`, *optional*, defaults to 450):
The maximum sequence length of text features that this model might ever be used with.
encoder_max_relative_position (`int`, *optional*, defaults to 160):
Maximum distance for relative position embedding in the encoder.
use_guided_attention_loss (`bool`, *optional*, defaults to `True`):
Whether to apply guided attention loss while training the TTS model.
guided_attention_loss_num_heads (`int`, *optional*, defaults to 2):
Number of attention heads the guided attention loss will be applied to. Use -1 to apply this loss to all
attention heads.
guided_attention_loss_sigma (`float`, *optional*, defaults to 0.4):
Standard deviation for guided attention loss.
guided_attention_loss_scale (`float`, *optional*, defaults to 10.0):
Scaling coefficient for guided attention loss (also known as lambda).
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).

@ -241,6 +250,10 @@ class SpeechT5Config(PretrainedConfig):
max_speech_positions=4000,
max_text_positions=450,
encoder_max_relative_position=160,
use_guided_attention_loss=True,
guided_attention_loss_num_heads=2,
guided_attention_loss_sigma=0.4,
guided_attention_loss_scale=10.0,
use_cache=True,
is_encoder_decoder=True,
**kwargs,

@ -311,6 +324,12 @@ class SpeechT5Config(PretrainedConfig):
self.max_speech_positions = max_speech_positions
self.max_text_positions = max_text_positions
self.encoder_max_relative_position = encoder_max_relative_position

self.use_guided_attention_loss = use_guided_attention_loss
self.guided_attention_loss_num_heads = guided_attention_loss_num_heads
self.guided_attention_loss_sigma = guided_attention_loss_sigma
self.guided_attention_loss_scale = guided_attention_loss_scale

self.use_cache = use_cache
self.is_encoder_decoder = is_encoder_decoder
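For reference, a minimal sketch of how the new loss-related options might be set when constructing a config (not part of the diff; the values shown are just the documented defaults):

from transformers import SpeechT5Config

config = SpeechT5Config(
    reduction_factor=2,
    use_guided_attention_loss=True,      # penalize non-diagonal cross-attention during TTS training
    guided_attention_loss_num_heads=2,   # or -1 to apply the loss to every head
    guided_attention_loss_sigma=0.4,
    guided_attention_loss_scale=10.0,
)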
@ -351,7 +351,6 @@ def convert_speecht5_checkpoint(
if vocab_path:
tokenizer = SpeechT5Tokenizer(vocab_path, model_max_length=config.max_text_positions)

if task == "pretrain":
# Mask token behaves like a normal word, i.e. include the space before it
mask_token = AddedToken("<mask>", lstrip=True, rstrip=False)
tokenizer.mask_token = mask_token
@ -14,12 +14,13 @@
# limitations under the License.
"""Feature extractor class for SpeechT5."""

from typing import List, Optional, Union
import warnings
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
import torchaudio

from ...audio_utils import get_mel_filter_banks
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging

@ -60,7 +61,7 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
win_function (`str`, *optional*, defaults to `"hann_window"`):
Name for the window function used for windowing, must be accessible via `torch.{win_function}`
frame_signal_scale (`float`, *optional*, defaults to 1.0):
Constant multiplied in creating the frames before applying DFT.
Constant multiplied in creating the frames before applying DFT. This argument is deprecated.
fmin (`float`, *optional*, defaults to 80):
Minimum mel frequency in Hz.
fmax (`float`, *optional*, defaults to 7600):

@ -68,7 +69,7 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
mel_floor (`float`, *optional*, defaults to 1e-10):
Minimum value of mel frequency banks.
reduction_factor (`int`, *optional*, defaults to 2):
Spectrogram length reduction factor.
Spectrogram length reduction factor. This argument is deprecated.
return_attention_mask (`bool`, *optional*, defaults to `True`):
Whether or not [`~SpeechT5FeatureExtractor.__call__`] should return `attention_mask`.
"""

@ -109,10 +110,33 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):

self.sample_size = win_length * sampling_rate // 1000
self.sample_stride = hop_length * sampling_rate // 1000

self.n_fft = 2 ** int(np.ceil(np.log2(self.sample_size)))
self.n_freqs = (self.n_fft // 2) + 1

window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True)
self.window = window.numpy().astype(np.float64)

self.mel_filters = get_mel_filter_banks(
nb_frequency_bins=self.n_freqs,
nb_mel_filters=self.num_mel_bins,
frequency_min=self.fmin,
frequency_max=self.fmax,
sample_rate=self.sampling_rate,
norm="slaney",
mel_scale="slaney",
)

if frame_signal_scale != 1.0:
warnings.warn(
"The argument `frame_signal_scale` is deprecated and will be removed in version 4.30.0 of Transformers",
FutureWarning,
)
if reduction_factor != 2.0:
warnings.warn(
"The argument `reduction_factor` is deprecated and will be removed in version 4.30.0 of Transformers",
FutureWarning,
)

@staticmethod
# Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
def zero_mean_unit_var_norm(
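With the settings used by the tests further down (sampling_rate=16000, win_length=64 ms, hop_length=16 ms), the derived quantities computed in this constructor work out as follows (a small standalone check, not part of the diff):

import numpy as np

sampling_rate, win_length, hop_length = 16000, 64, 16
sample_size = win_length * sampling_rate // 1000    # 1024 samples per analysis window
sample_stride = hop_length * sampling_rate // 1000  # 256 samples between frames
n_fft = 2 ** int(np.ceil(np.log2(sample_size)))     # 1024, already a power of two
n_freqs = (n_fft // 2) + 1                          # 513 frequency bins
print(sample_size, sample_stride, n_fft, n_freqs)   # 1024 256 1024 513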
@ -137,99 +161,45 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
return normed_input_values

@staticmethod
def _center_pad(one_waveform, n_fft, pad_mode):
padding = [(int(n_fft // 2), int(n_fft // 2))]
return np.pad(one_waveform, padding, mode=pad_mode)
def _stft(waveform: np.ndarray, fft_length: int, hop_length: int, window: np.ndarray) -> np.ndarray:
"""
Calculates the magnitude spectrogram over one waveform array.
"""
# center pad the waveform
padding = [(int(fft_length // 2), int(fft_length // 2))]
waveform = np.pad(waveform, padding, mode="reflect")
waveform_size = waveform.size

@staticmethod
# Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._num_frames_calc
def _num_frames_calc(in_size, frame_size, frame_stride):
return int(1 + np.floor((in_size - frame_size) * 1 / frame_stride))
# promote to float64, since np.fft uses float64 internally
waveform = waveform.astype(np.float64)

@staticmethod
# Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._frame_signal
def _frame_signal(one_waveform, n_frames, frame_signal_scale, window_length, sample_stride):
scale = frame_signal_scale
frames = np.zeros(n_frames * window_length)
for frame_idx in range(n_frames):
start = frame_idx * window_length
end = (frame_idx + 1) * window_length
wave_start = frame_idx * sample_stride
wave_end = frame_idx * sample_stride + window_length
frames[start:end] = scale * one_waveform[wave_start:wave_end]
num_frames = int(1 + np.floor((waveform_size - fft_length) / hop_length))
num_frequency_bins = (fft_length // 2) + 1
spectrogram = np.empty((num_frames, num_frequency_bins))

return frames
start = 0
for frame_idx in range(num_frames):
frame = waveform[start : start + fft_length] * window
spectrogram[frame_idx] = np.abs(np.fft.rfft(frame))
start += hop_length

@staticmethod
# Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._windowing
def _windowing(frames, window_length, window):
if frames.size % window_length != 0:
raise ValueError(
f"`frames` is supposed to have length divisble by `window_length`, but is {frames.size} with"
f" window_length={window_length}."
)
return spectrogram

shaped = frames.reshape(-1, window_length)
shaped = window * shaped
return shaped

@staticmethod
# Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._dft
def _dft(frames, K, n_frames, n_samples, n_fft):
dft = np.zeros([n_frames, K])

for frame in range(n_frames):
begin = frame * n_samples

inwards_buffer = frames[begin : begin + n_samples]
inwards_buffer = np.pad(inwards_buffer, (0, n_fft - n_samples), "constant")
out = np.fft.rfft(inwards_buffer)

dft[frame] = np.abs(out[:K])

return dft

def _extract_fbank_features(
def _extract_mel_features(
self,
one_waveform: np.ndarray,
) -> np.ndarray:
"""
Extracts log-mel filterbank features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC
code and librosa.
Extracts log-mel filterbank features for one waveform array (unbatched).
"""
one_waveform = self._center_pad(one_waveform, self.n_fft, "reflect")

n_frames = self._num_frames_calc(one_waveform.size, self.sample_size, self.sample_stride)

frames = self._frame_signal(
one_waveform, n_frames, self.frame_signal_scale, self.sample_size, self.sample_stride
if self.n_fft != self.sample_size:
raise NotImplementedError(
f"Currently the STFT frame size must be a power of two, but got {self.sample_size} for a window length of {self.win_length} and sampling rate of {self.sampling_rate}. Ensure `win_length * sampling_rate // 1000` is divisible by two."
)

window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True)
window = window.numpy()
stft_out = self._stft(one_waveform, self.n_fft, self.sample_stride, self.window)

frames = self._windowing(frames, self.sample_size, window)

dft_out = self._dft(frames.flatten(), self.n_freqs, n_frames, self.sample_size, self.n_fft)

fbanks = torchaudio.functional.melscale_fbanks(
n_freqs=self.n_freqs,
f_min=self.fmin,
f_max=self.fmax,
n_mels=self.num_mel_bins,
sample_rate=self.sampling_rate,
norm="slaney",
mel_scale="slaney",
)
fbanks = fbanks.numpy()

return np.log10(np.maximum(self.mel_floor, np.dot(dft_out, fbanks)))

def _reduce(self, inputs):
reduced = []
for i in range(len(inputs)):
reduced.append(inputs[i][self.reduction_factor - 1 :: self.reduction_factor])
return reduced
return np.log10(np.maximum(self.mel_floor, np.dot(stft_out, self.mel_filters)))

def __call__(
self,
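In other words, the rewritten extraction path is a plain NumPy magnitude-STFT followed by a single precomputed mel projection. A standalone sketch of the same computation (assuming a Hann window of length n_fft and a mel filter matrix of shape (n_freqs, num_mel_bins) are already available, e.g. the self.window and self.mel_filters built in the constructor above):

import numpy as np

def log_mel(waveform, n_fft, hop, window, mel_filters, mel_floor=1e-10):
    # center-pad so the first frame is centred on sample 0
    waveform = np.pad(waveform, [(n_fft // 2, n_fft // 2)], mode="reflect").astype(np.float64)
    num_frames = int(1 + np.floor((waveform.size - n_fft) / hop))
    spec = np.empty((num_frames, n_fft // 2 + 1))
    for i in range(num_frames):
        frame = waveform[i * hop : i * hop + n_fft] * window
        spec[i] = np.abs(np.fft.rfft(frame))        # magnitude spectrogram
    # project onto the mel filter banks and take the floored log
    return np.log10(np.maximum(mel_floor, spec @ mel_filters))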
@ -341,7 +311,6 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
return inputs_target
else:
inputs["labels"] = inputs_target["input_values"]
inputs["stop_labels"] = inputs_target["stop_labels"]
decoder_attention_mask = inputs_target.get("attention_mask")
if decoder_attention_mask is not None:
inputs["decoder_attention_mask"] = decoder_attention_mask

@ -381,8 +350,7 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):

# convert into correct format for padding
if is_target:
features = [self._extract_fbank_features(waveform) for waveform in speech]
fbank_sizes = [len(x) for x in features]
features = [self._extract_mel_features(waveform) for waveform in speech]
encoded_inputs = BatchFeature({"input_values": features})
self.feature_size = self.num_mel_bins
else:

@ -429,22 +397,18 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value
)

if is_target:
# make labels for stop prediction
stop_labels = []
for i, l in enumerate(fbank_sizes):
labels = np.zeros(len(padded_inputs["input_values"][i]))
labels[l - 1 :] = 1.0
stop_labels.append(labels)
padded_inputs["stop_labels"] = stop_labels

# thin out frames for reduction factor
if self.reduction_factor > 1:
padded_inputs["input_values"] = self._reduce(padded_inputs["input_values"])
if attention_mask is not None:
padded_inputs["attention_mask"] = self._reduce(padded_inputs["attention_mask"])

if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

return padded_inputs

def to_dict(self) -> Dict[str, Any]:
output = super().to_dict()

# Don't serialize these as they are derived from the other properties.
names = ["window", "mel_filters", "sample_size", "sample_stride", "n_fft", "n_freqs"]
for name in names:
if name in output:
del output[name]

return output
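A short sketch of the target path through __call__ after this change (the waveform is a dummy array): with audio_target the extractor now returns full-length padded log-mel input_values plus an attention_mask for the decoder, and no stop_labels.

import numpy as np
from transformers import SpeechT5FeatureExtractor

feature_extractor = SpeechT5FeatureExtractor()
waveform = np.zeros(16000, dtype=np.float32)        # 1 second of silence as a stand-in

targets = feature_extractor(audio_target=waveform, sampling_rate=16000, return_tensors="np")
print(targets["input_values"].shape)                # (1, frames, num_mel_bins)
print(list(targets.keys()))                         # input_values, attention_mask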
@ -16,13 +16,14 @@

import math
import random
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss

from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled

@ -72,12 +73,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
return shifted_input_ids


def shift_spectrograms_right(input_values: torch.Tensor):
def shift_spectrograms_right(input_values: torch.Tensor, reduction_factor: int = 1):
"""
Shift input spectrograms one timestep to the right.
Shift input spectrograms one timestep to the right. Also applies the reduction factor to the sequence length.
"""
# thin out frames for reduction factor
if reduction_factor > 1:
input_values = input_values[:, reduction_factor - 1 :: reduction_factor]

shifted_input_values = input_values.new_zeros(input_values.shape)
shifted_input_values[:, 1:] = input_values[:, :-1].clone()

# replace possible -100 values in labels by zeros
shifted_input_values.masked_fill_(shifted_input_values == -100.0, 0.0)

return shifted_input_values
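A tiny worked example of the new behaviour (not part of the diff), assuming reduction_factor=2: the labels are first thinned to every second frame, then shifted right by one step so the decoder sees the previous frame at each position, with any -100.0 padding zeroed out.

import torch

labels = torch.tensor([[[1.0], [2.0], [3.0], [4.0], [-100.0], [-100.0]]])  # (batch, frames, bins)
reduced = labels[:, 2 - 1 :: 2]                      # keep frames 2, 4, 6 -> [[2.], [4.], [-100.]]
shifted = reduced.new_zeros(reduced.shape)
shifted[:, 1:] = reduced[:, :-1].clone()             # [[0.], [2.], [4.]]
shifted.masked_fill_(shifted == -100.0, 0.0)         # no -100 survives the shift in this example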
@ -565,7 +574,7 @@ class SpeechT5SpeechEncoderPrenet(nn.Module):
def forward(
self,
input_values: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
mask_time_indices: Optional[torch.FloatTensor] = None,
):
extract_features = self.feature_encoder(input_values)

@ -840,7 +849,7 @@ class SpeechT5TextDecoderPrenet(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
):
if input_ids is not None:

@ -1574,7 +1583,7 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel):
def forward(
self,
hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,

@ -1589,7 +1598,7 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel):
Args:
hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`):
Features extracted from the speech or text input by the decoder prenet.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

- 1 for tokens that are **not masked**,

@ -1783,7 +1792,7 @@ class SpeechT5DecoderWithSpeechPrenet(SpeechT5PreTrainedModel):
def forward(
self,
input_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeddings: Optional[torch.Tensor] = None,

@ -1837,7 +1846,7 @@ class SpeechT5DecoderWithTextPrenet(SpeechT5PreTrainedModel):
def forward(
self,
input_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,

@ -1884,7 +1893,7 @@ class SpeechT5DecoderWithoutPrenet(SpeechT5PreTrainedModel):
def forward(
self,
input_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,

@ -1911,6 +1920,126 @@ class SpeechT5DecoderWithoutPrenet(SpeechT5PreTrainedModel):
return outputs

class SpeechT5GuidedMultiheadAttentionLoss(nn.Module):
"""
Guided attention loss from the paper [Efficiently Trainable Text-to-Speech System Based on Deep Convolutional
Networks with Guided Attention](https://arxiv.org/abs/1710.08969), adapted for multi-head attention.
"""

def __init__(self, config: SpeechT5Config):
super().__init__()
self.sigma = config.guided_attention_loss_sigma
self.scale = config.guided_attention_loss_scale

def forward(
self, attentions: torch.FloatTensor, input_masks: torch.BoolTensor, output_masks: torch.BoolTensor
) -> torch.Tensor:
"""
Compute the attention loss.

Args:
attentions (`torch.FloatTensor` of shape `(batch_size, layers * heads, output_sequence_length, input_sequence_length)`):
Batch of multi-head attention weights
input_masks (`torch.BoolTensor` of shape `(batch_size, input_sequence_length)`):
Input attention mask as booleans.
output_masks (`torch.BoolTensor` of shape `(batch_size, output_sequence_length)`):
Target attention mask as booleans.

Returns:
`torch.Tensor` with the loss value
"""
guided_attn_masks = self._make_guided_attention_masks(input_masks, output_masks, attentions.device)
masks = output_masks.unsqueeze(-1) & input_masks.unsqueeze(-2)
masks = masks.to(attentions.device).unsqueeze(1)

losses = guided_attn_masks * attentions
loss = torch.mean(losses.masked_select(masks))
return self.scale * loss

def _make_guided_attention_masks(self, input_masks, output_masks, device):
input_lengths = input_masks.sum(-1)
output_lengths = output_masks.sum(-1)

guided_attn_masks = torch.zeros((len(input_masks), output_masks.shape[1], input_masks.shape[1]), device=device)

for idx, (ilen, olen) in enumerate(zip(input_lengths, output_lengths)):
guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(ilen, olen, self.sigma, device)

return guided_attn_masks.unsqueeze(1)

@staticmethod
def _make_guided_attention_mask(input_length, output_length, sigma, device):
grid_y, grid_x = torch.meshgrid(
torch.arange(input_length, device=device),
torch.arange(output_length, device=device),
indexing="xy",
)
grid_x = grid_x.float() / output_length
grid_y = grid_y.float() / input_length
return 1.0 - torch.exp(-((grid_y - grid_x) ** 2) / (2 * (sigma**2)))
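The per-utterance mask built by _make_guided_attention_mask is the soft diagonal from the guided-attention paper: entry (t, n) is 1 - exp(-((n / input_length - t / output_length)^2) / (2 * sigma^2)), so attention weights near the text/speech diagonal are barely penalized while attention far from it costs the most. A quick standalone check of that construction (not part of the diff):

import torch

sigma, input_length, output_length = 0.4, 4, 6
grid_y, grid_x = torch.meshgrid(
    torch.arange(input_length), torch.arange(output_length), indexing="xy"
)
mask = 1.0 - torch.exp(
    -((grid_y.float() / input_length - grid_x.float() / output_length) ** 2) / (2 * sigma**2)
)
print(mask.shape)         # torch.Size([6, 4]) -> (output_length, input_length)
print(mask[0, 0].item())  # 0.0 at the diagonal origin, growing away from it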
class SpeechT5SpectrogramLoss(nn.Module):
"""
Loss computation used by SpeechT5ForTextToSpeech.
"""

def __init__(self, config: SpeechT5Config):
super().__init__()
self.use_guided_attention_loss = config.use_guided_attention_loss
self.guided_attention_loss_num_heads = config.guided_attention_loss_num_heads
self.reduction_factor = config.reduction_factor

self.l1_criterion = L1Loss()
self.bce_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(5.0))

if self.use_guided_attention_loss:
self.attn_criterion = SpeechT5GuidedMultiheadAttentionLoss(config)

def forward(
self,
attention_mask: torch.LongTensor,
outputs_before_postnet: torch.FloatTensor,
outputs_after_postnet: torch.FloatTensor,
logits: torch.FloatTensor,
labels: torch.FloatTensor,
cross_attentions: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
padding_mask = labels != -100.0

# mask out the padded portions
labels = labels.masked_select(padding_mask)
outputs_before_postnet = outputs_before_postnet.masked_select(padding_mask)
outputs_after_postnet = outputs_after_postnet.masked_select(padding_mask)

# spectrogram loss
l1_loss = self.l1_criterion(outputs_after_postnet, labels) + self.l1_criterion(outputs_before_postnet, labels)

# construct stop labels from the padding mask
masks = padding_mask[:, :, 0]
stop_labels = torch.cat([~masks * 1.0, torch.ones(masks.size(0), 1).to(masks.device)], dim=1)
stop_labels = stop_labels[:, 1:].masked_select(masks)
logits = logits.masked_select(masks)

# stop token loss
bce_loss = self.bce_criterion(logits, stop_labels)

# combined loss
loss = l1_loss + bce_loss

# guided attention loss
if self.use_guided_attention_loss:
attn = torch.cat([x[:, : self.guided_attention_loss_num_heads] for x in cross_attentions], dim=1)
input_masks = attention_mask == 1
output_masks = padding_mask[:, :, 0]
if self.reduction_factor > 1:
output_masks = output_masks[:, self.reduction_factor - 1 :: self.reduction_factor]
attn_loss = self.attn_criterion(attn, input_masks, output_masks)
loss += attn_loss

return loss
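Since stop_labels no longer comes from the feature extractor, the loss derives stop targets from the -100 padding mask itself. A small example (not part of the diff) with one utterance of 5 frames of which 3 are valid: the last valid frame, and only that one, gets a stop target of 1.

import torch

masks = torch.tensor([[True, True, True, False, False]])            # valid-frame mask (batch, frames)
stop_labels = torch.cat([~masks * 1.0, torch.ones(masks.size(0), 1)], dim=1)
stop_labels = stop_labels[:, 1:].masked_select(masks)
print(stop_labels)                                                   # tensor([0., 0., 1.])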
SPEECHT5_BASE_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads

@ -1981,13 +2110,13 @@ SPEECHT5_INPUTS_DOCSTRING = r"""
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.

head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
head_mask (`torch.FloatTensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.

decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
decoder_head_mask (`torch.FloatTensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:

- 1 indicates the head is **not masked**,

@ -2088,27 +2217,27 @@ class SpeechT5Model(SpeechT5PreTrainedModel):
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
input_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_values: Optional[torch.Tensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
speaker_embeddings: Optional[torch.Tensor] = None,
speaker_embeddings: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`):
Depending on which encoder is being used, the `input_values` are either: float values of the input raw
speech waveform, or indices of input sequence tokens in the vocabulary, or hidden states.

decoder_input_values (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
decoder_input_values (`torch.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Depending on which decoder is being used, the `decoder_input_values` are either: float values of log-mel
filterbank features extracted from the raw speech waveform, or indices of decoder input sequence tokens in
the vocabulary, or hidden states.

@ -2246,10 +2375,10 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.Tensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
input_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,

@ -2259,7 +2388,7 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Seq2SeqLMOutput]:
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):

@ -2414,7 +2543,8 @@ def _generate_speech(
minlenratio: float = 0.0,
maxlenratio: float = 20.0,
vocoder: Optional[nn.Module] = None,
) -> torch.FloatTensor:
output_cross_attentions: bool = False,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
encoder_attention_mask = torch.ones_like(input_values)

encoder_out = model.speecht5.encoder(

@ -2438,6 +2568,7 @@ def _generate_speech(
output_sequence = encoder_last_hidden_state.new_zeros(1, 1, model.config.num_mel_bins)

spectrogram = []
cross_attentions = []
past_key_values = None
idx = 0

@ -2455,9 +2586,13 @@ def _generate_speech(
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=True,
output_attentions=output_cross_attentions,
return_dict=True,
)

if output_cross_attentions:
cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0))

last_decoder_output = decoder_out.last_hidden_state[0, -1]
past_key_values = decoder_out.past_key_values

@ -2480,9 +2615,15 @@ def _generate_speech(
break

if vocoder is not None:
return vocoder(spectrogram)
outputs = vocoder(spectrogram)
else:
return spectrogram
outputs = spectrogram

if output_cross_attentions:
cross_attentions = torch.cat(cross_attentions, dim=2)
outputs = (outputs, cross_attentions)

return outputs


@add_start_docstrings(
@ -2525,10 +2666,10 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
@replace_return_docstrings(output_type=Seq2SeqSpectrogramOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_values: Optional[torch.Tensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_values: Optional[torch.FloatTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,

@ -2538,8 +2679,8 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
speaker_embeddings: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
speaker_embeddings: Optional[torch.FloatTensor] = None,
labels: Optional[torch.FloatTensor] = None,
stop_labels: Optional[torch.Tensor] = None,
) -> Union[Tuple, Seq2SeqSpectrogramOutput]:
r"""

@ -2559,13 +2700,9 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
Tensor containing the speaker embeddings.
labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*):
Float values of target mel spectrogram. Spectrograms can be obtained using [`SpeechT5Processor`]. See
[`SpeechT5Processor.__call__`] for details.
stop_labels (`torch.FloatTensor` of shape `(batch_size, unreduced_sequence_length)`, *optional*):
Labels for computing the stop token loss. Values are 0.0 until the end of the sequence, after which they
become 1.0. The sequence length of this tensor is `config.reduction_factor` times larger than the length of
the target mel spectrogram. Labels can be obtained using [`SpeechT5Processor`]. See
[`SpeechT5Processor.__call__`] for details.
Float values of target mel spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss
computation. Spectrograms can be obtained using [`SpeechT5Processor`]. See [`SpeechT5Processor.__call__`]
for details.

Returns:

@ -2592,9 +2729,17 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if stop_labels is not None:
warnings.warn(
"The argument `stop_labels` is deprecated and will be removed in version 4.30.0 of Transformers",
FutureWarning,
)

if labels is not None:
if decoder_input_values is None:
decoder_input_values = shift_spectrograms_right(labels)
decoder_input_values = shift_spectrograms_right(labels, self.config.reduction_factor)
if self.config.use_guided_attention_loss:
output_attentions = True

outputs = self.speecht5(
input_values=input_ids,

@ -2613,17 +2758,27 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
return_dict=True,
)

_, spectrogram, logits = self.speech_decoder_postnet(outputs[0])
outputs_before_postnet, outputs_after_postnet, logits = self.speech_decoder_postnet(outputs[0])

loss = None
if labels is not None:
criterion = SpeechT5SpectrogramLoss(self.config)
loss = criterion(
attention_mask,
outputs_before_postnet,
outputs_after_postnet,
logits,
labels,
outputs.cross_attentions,
)

if not return_dict:
output = (spectrogram,) + outputs[1:]
output = (outputs_after_postnet,) + outputs[1:]
return ((loss,) + output) if loss is not None else output

return Seq2SeqSpectrogramOutput(
loss=loss,
spectrogram=spectrogram,
spectrogram=outputs_after_postnet,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,

@ -2642,7 +2797,8 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
minlenratio: float = 0.0,
maxlenratio: float = 20.0,
vocoder: Optional[nn.Module] = None,
) -> torch.FloatTensor:
output_cross_attentions: bool = False,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
r"""
Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a
speech waveform using a vocoder.

@ -2666,10 +2822,18 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
vocoder (`nn.Module`, *optional*, defaults to `None`):
The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel
spectrogram.
output_cross_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of the decoder's cross-attention layers.

Returns:
`torch.FloatTensor`: Tensor of shape `(output_sequence_length, config.num_mel_bins)` containing the
predicted mel spectrogram, or a tensor with shape `(num_frames,)` containing the speech waveform.
`tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
- **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
`(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
- **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
`(num_frames,)` -- The predicted speech waveform.
- **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) `torch.FloatTensor`
of shape `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length,
input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
"""
return _generate_speech(
self,

@ -2679,6 +2843,7 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
minlenratio,
maxlenratio,
vocoder,
output_cross_attentions,
)

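A usage sketch for the new flag (not part of the diff; the checkpoint name and random speaker embedding are placeholders): with output_cross_attentions=True and no vocoder, generate_speech returns the spectrogram together with the stacked cross-attentions described above.

import torch
from transformers import SpeechT5ForTextToSpeech, SpeechT5Processor

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")

input_ids = processor(text="hello world", return_tensors="pt")["input_ids"]
speaker_embeddings = torch.randn(1, 512)  # placeholder x-vector

spectrogram, cross_attentions = model.generate_speech(
    input_ids, speaker_embeddings, output_cross_attentions=True
)
print(spectrogram.shape)       # (output_sequence_length, num_mel_bins)
print(cross_attentions.shape)  # (decoder_layers, heads, output_sequence_length, input_sequence_length)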
@ -2723,10 +2888,10 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
@replace_return_docstrings(output_type=Seq2SeqSpectrogramOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_values: Optional[torch.Tensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
input_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_values: Optional[torch.FloatTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,

@ -2736,8 +2901,8 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
speaker_embeddings: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
speaker_embeddings: Optional[torch.FloatTensor] = None,
labels: Optional[torch.FloatTensor] = None,
stop_labels: Optional[torch.Tensor] = None,
) -> Union[Tuple, Seq2SeqSpectrogramOutput]:
r"""

@ -2757,11 +2922,6 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*):
Float values of target mel spectrogram. Spectrograms can be obtained using [`SpeechT5Processor`]. See
[`SpeechT5Processor.__call__`] for details.
stop_labels (`torch.FloatTensor` of shape `(batch_size, unreduced_sequence_length)`, *optional*):
Labels for computing the stop token loss. Values are 0.0 until the end of the sequence, after which they
become 1.0. The sequence length of this tensor is `config.reduction_factor` times larger than the length of
the target mel spectrogram. Labels can be obtained using [`SpeechT5Processor`]. See
[`SpeechT5Processor.__call__`] for details.

Returns:

@ -2797,9 +2957,15 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if stop_labels is not None:
warnings.warn(
"The argument `stop_labels` is deprecated and will be removed in version 4.30.0 of Transformers",
FutureWarning,
)

if labels is not None:
if decoder_input_values is None:
decoder_input_values = shift_spectrograms_right(labels)
decoder_input_values = shift_spectrograms_right(labels, self.config.reduction_factor)

outputs = self.speecht5(
input_values=input_values,

@ -2847,6 +3013,7 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
minlenratio: float = 0.0,
maxlenratio: float = 20.0,
vocoder: Optional[nn.Module] = None,
output_cross_attentions: bool = False,
) -> torch.FloatTensor:
r"""
Converts a raw speech waveform into a sequence of mel spectrograms, which are subsequently turned back into a

@ -2871,10 +3038,18 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
vocoder (`nn.Module`, *optional*, defaults to `None`):
The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel
spectrogram.
output_cross_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of the decoder's cross-attention layers.

Returns:
`torch.FloatTensor`: Tensor of shape `(output_sequence_length, config.num_mel_bins)` containing the
predicted mel spectrogram, or a tensor with shape `(num_frames,)` containing the speech waveform.
`tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
- **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
`(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
- **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
`(num_frames,)` -- The predicted speech waveform.
- **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) `torch.FloatTensor`
of shape `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length,
input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
"""
if speaker_embeddings is None:
speaker_embeddings = torch.zeros((1, 512), device=input_values.device)

@ -2887,6 +3062,7 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
minlenratio,
maxlenratio,
vocoder,
output_cross_attentions,
)


@ -3034,7 +3210,7 @@ class SpeechT5HifiGan(PreTrainedModel):
layer.remove_weight_norm()
nn.utils.remove_weight_norm(self.conv_post)

def forward(self, spectrogram):
def forward(self, spectrogram: torch.FloatTensor) -> torch.FloatTensor:
r"""
Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch
of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech
@ -87,25 +87,21 @@ class SpeechT5Processor(ProcessorMixin):
inputs = None

if audio_target is not None:
audio_target_features = self.feature_extractor(
audio_target=audio_target, *args, sampling_rate=sampling_rate, **kwargs
)
if inputs is None:
return audio_target_features
targets = self.feature_extractor(audio_target=audio_target, *args, sampling_rate=sampling_rate, **kwargs)
labels = targets["input_values"]
elif text_target is not None:
targets = self.tokenizer(text_target, **kwargs)
labels = targets["input_ids"]
else:
inputs["labels"] = audio_target_features["input_values"]
inputs["stop_labels"] = audio_target_features["stop_labels"]
decoder_attention_mask = audio_target_features.get("attention_mask")
if decoder_attention_mask is not None:
inputs["decoder_attention_mask"] = decoder_attention_mask
targets = None

if text_target is not None:
encodings_target = self.tokenizer(text_target, **kwargs)
if inputs is None:
return encodings_target
else:
inputs["labels"] = encodings_target["input_ids"]
decoder_attention_mask = encodings_target.get("attention_mask")
return targets

if targets is not None:
inputs["labels"] = labels

decoder_attention_mask = targets.get("attention_mask")
if decoder_attention_mask is not None:
inputs["decoder_attention_mask"] = decoder_attention_mask
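After this change the target branch is symmetric for both tasks. Roughly (a sketch only, reusing the `processor` and a 16 kHz `waveform` array like the ones in the earlier sketch; the keyword names follow the code above): passing audio with text_target yields token labels for speech-to-text, while text with audio_target yields log-mel labels plus a decoder_attention_mask for text-to-speech.

# speech-to-text style: waveform in, tokenized transcription as labels
asr_inputs = processor(audio=waveform, sampling_rate=16000, text_target="a transcription")

# text-to-speech style: token ids in, log-mel spectrogram of the target audio as labels
tts_inputs = processor(text="a sentence to speak", audio_target=waveform, sampling_rate=16000)
print(list(tts_inputs.keys()))  # input_ids, attention_mask, labels, decoder_attention_mask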
@ -113,33 +109,63 @@ class SpeechT5Processor(ProcessorMixin):

def pad(self, *args, **kwargs):
"""
This method forwards all its arguments to SpeechT5FeatureExtractor's [`~SpeechT5FeatureExtractor.pad`] and
returns its output.
Collates the audio and text inputs, as well as their targets, into a padded batch.

You can process your labels by using the argument `text` (either in the same call as your audio inputs, or in a
separate call). This forwards its arguments to SpeechT5Tokenizer's [`~SpeechT5Tokenizer.pad`].
Audio inputs are padded by SpeechT5FeatureExtractor's [`~SpeechT5FeatureExtractor.pad`]. Text inputs are padded
by SpeechT5Tokenizer's [`~SpeechT5Tokenizer.pad`].

Valid input combinations are:

- `input_ids` only
- `input_values` only
- `labels` only, either log-mel spectrograms or text tokens
- `input_ids` and log-mel spectrogram `labels`
- `input_values` and text `labels`

Please refer to the docstring of the above two methods for more information.
"""
input_features = kwargs.pop("input_features", None)
input_values = kwargs.pop("input_values", None)
input_ids = kwargs.pop("input_ids", None)
labels = kwargs.pop("labels", None)

if len(args) > 0:
input_features = args[0]
args = args[1:]
if input_values is not None and input_ids is not None:
raise ValueError("Cannot process both `input_values` and `input_ids` inputs.")
if input_values is None and input_ids is None and labels is None:
raise ValueError(
"You need to specify either an `input_values`, `input_ids`, or `labels` input to be padded."
)

if input_features is not None:
input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
if labels is not None:
labels = self.tokenizer.pad(labels, **kwargs)

if labels is None:
return input_features
elif input_features is None:
return labels
if input_values is not None:
inputs = self.feature_extractor.pad(input_values, *args, **kwargs)
elif input_ids is not None:
inputs = self.tokenizer.pad(input_ids, **kwargs)
else:
input_features["labels"] = labels["input_ids"]
return input_features
inputs = None

if labels is not None:
if "input_ids" in labels or (isinstance(labels, list) and "input_ids" in labels[0]):
targets = self.tokenizer.pad(labels, **kwargs)
labels = targets["input_ids"]
else:
feature_size_hack = self.feature_extractor.feature_size
self.feature_extractor.feature_size = self.feature_extractor.num_mel_bins
targets = self.feature_extractor.pad(labels, *args, **kwargs)
self.feature_extractor.feature_size = feature_size_hack
labels = targets["input_values"]
else:
targets = None

if inputs is None:
return targets

if targets is not None:
inputs["labels"] = labels

decoder_attention_mask = targets.get("attention_mask")
if decoder_attention_mask is not None:
inputs["decoder_attention_mask"] = decoder_attention_mask

return inputs

def batch_decode(self, *args, **kwargs):
"""
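This makes SpeechT5Processor.pad usable directly as a collate function for TTS fine-tuning. A minimal sketch (not part of the diff; the feature names follow the list above, and the -100 replacement is an optional extra step so padded frames are ignored by the spectrogram loss):

def tts_collate_fn(examples, processor):
    # each example: {"input_ids": list of token ids, "labels": 2-D log-mel array}
    batch = processor.pad(
        input_ids=[{"input_ids": ex["input_ids"]} for ex in examples],
        labels=[{"input_values": ex["labels"]} for ex in examples],
        return_tensors="pt",
    )
    # optionally mask padded target frames so the loss skips them
    batch["labels"] = batch["labels"].masked_fill(
        batch["decoder_attention_mask"].unsqueeze(-1).ne(1), -100
    )
    return batch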
@ -167,6 +167,26 @@ class SpeechT5Tokenizer(PreTrainedTokenizer):
out_string += self.sp_model.decode(current_sub_tokens)
return out_string.strip()

def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
"""Build model inputs from a sequence by appending eos_token_id."""
if token_ids_1 is None:
return token_ids_0 + [self.eos_token_id]
# We don't expect to process pairs, but leave the pair logic for API consistency
return token_ids_0 + token_ids_1 + [self.eos_token_id]

def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)

suffix_ones = [1]
if token_ids_1 is None:
return ([0] * len(token_ids_0)) + suffix_ones
return ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones

def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
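A quick illustration of the two new methods (not part of the diff; `tokenizer` is assumed to be an instantiated SpeechT5Tokenizer and the token ids are made up): the tokenizer now appends eos_token_id itself, and the special-tokens mask marks only that final position.

token_ids = [12, 7, 9]                                           # hypothetical ids
print(tokenizer.build_inputs_with_special_tokens(token_ids))    # [12, 7, 9, eos_token_id]
print(tokenizer.get_special_tokens_mask(token_ids))             # [0, 0, 0, 1]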
@ -21,7 +21,7 @@ import unittest
import numpy as np

from transformers import BatchFeature, is_speech_available
from transformers.testing_utils import require_torch, require_torchaudio
from transformers.testing_utils import require_torch
from transformers.utils.import_utils import is_torch_available

from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin

@ -67,11 +67,9 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
hop_length=16,
win_length=64,
win_function="hann_window",
frame_signal_scale=1.0,
fmin=80,
fmax=7600,
mel_floor=1e-10,
reduction_factor=2,
return_attention_mask=True,
):
self.parent = parent

@ -87,11 +85,9 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
self.hop_length = hop_length
self.win_length = win_length
self.win_function = win_function
self.frame_signal_scale = frame_signal_scale
self.fmin = fmin
self.fmax = fmax
self.mel_floor = mel_floor
self.reduction_factor = reduction_factor
self.return_attention_mask = return_attention_mask

def prepare_feat_extract_dict(self):

@ -104,11 +100,9 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
"hop_length": self.hop_length,
"win_length": self.win_length,
"win_function": self.win_function,
"frame_signal_scale": self.frame_signal_scale,
"fmin": self.fmin,
"fmax": self.fmax,
"mel_floor": self.mel_floor,
"reduction_factor": self.reduction_factor,
"return_attention_mask": self.return_attention_mask,
}

@ -147,7 +141,6 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):

@require_torch
@require_torchaudio
class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = SpeechT5FeatureExtractor if is_speech_available() else None

@ -407,10 +400,10 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
def test_integration_target(self):
# fmt: off
EXPECTED_INPUT_VALUES = torch.tensor(
[-2.7713, -2.8896, -3.2619, -3.0843, -2.9919, -3.0084, -3.2796, -3.3169,
-3.2397, -3.2053, -2.9151, -2.7921, -2.9403, -2.7411, -3.0654, -2.8314,
-3.0026, -2.9797, -3.1314, -2.9939, -2.6748, -2.7725, -2.8563, -2.9462,
-3.2623, -3.3044, -3.1318, -3.2672, -3.4030, -3.1988]
[-2.6870, -3.0104, -3.1356, -3.5352, -3.0044, -3.0353, -3.4719, -3.6777,
-3.1520, -2.9435, -2.6553, -2.8795, -2.9944, -2.5921, -3.0279, -3.0386,
-3.0864, -3.1291, -3.2353, -2.7444, -2.6831, -2.7287, -3.1761, -3.1571,
-3.2726, -3.0582, -3.1007, -3.4533, -3.4695, -3.0998]
)
# fmt: on

@ -25,10 +25,10 @@ from transformers.testing_utils import (
require_sentencepiece,
require_tokenizers,
require_torch,
require_torchaudio,
slow,
torch_device,
)
from transformers.trainer_utils import set_seed
from transformers.utils import cached_property

from ...test_configuration_common import ConfigTester

@ -716,7 +716,6 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):

@require_torch
@require_torchaudio
@require_sentencepiece
@require_tokenizers
@slow

@ -991,7 +990,6 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):

@require_torch
@require_torchaudio
@require_sentencepiece
@require_tokenizers
@slow

@ -1005,11 +1003,13 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
model.to(torch_device)
processor = self.default_processor

set_seed(555)  # make deterministic

input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)

generated_speech = model.generate_speech(input_ids)
self.assertEqual(generated_speech.shape, (1800, model.config.num_mel_bins))
self.assertEqual(generated_speech.shape, (1820, model.config.num_mel_bins))


@require_torch

@ -1406,7 +1406,6 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase):

@require_torch
@require_torchaudio
@require_sentencepiece
@require_tokenizers
@slow

@ -21,7 +21,7 @@ import unittest

from transformers import is_speech_available, is_torch_available
from transformers.models.speecht5 import SpeechT5Tokenizer
from transformers.testing_utils import get_tests_dir, require_torch, require_torchaudio
from transformers.testing_utils import get_tests_dir, require_torch
from transformers.utils import FEATURE_EXTRACTOR_NAME


@ -35,7 +35,6 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_bpe_char.model")

@require_torch
@require_torchaudio
class SpeechT5ProcessorTest(unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()

@ -52,7 +51,6 @@ class SpeechT5ProcessorTest(unittest.TestCase):
"hop_length": 16,
"win_length": 64,
"win_function": "hann_window",
"frame_signal_scale": 1.0,
"fmin": 80,
"fmax": 7600,
"mel_floor": 1e-10,
File diff suppressed because one or more lines are too long