diff --git a/src/transformers/models/speecht5/configuration_speecht5.py b/src/transformers/models/speecht5/configuration_speecht5.py index fe5a5ebf149..8d6a61023c7 100644 --- a/src/transformers/models/speecht5/configuration_speecht5.py +++ b/src/transformers/models/speecht5/configuration_speecht5.py @@ -161,13 +161,22 @@ class SpeechT5Config(PretrainedConfig): speech_decoder_postnet_dropout (`float`, *optional*, defaults to 0.5): The dropout probability for the speech decoder post-net layers. reduction_factor (`int`, *optional*, defaults to 2): - Spectrogram length reduction factor for the speech decoder post-net. + Spectrogram length reduction factor for the speech decoder inputs. max_speech_positions (`int`, *optional*, defaults to 4000): The maximum sequence length of speech features that this model might ever be used with. max_text_positions (`int`, *optional*, defaults to 450): The maximum sequence length of text features that this model might ever be used with. encoder_max_relative_position (`int`, *optional*, defaults to 160): Maximum distance for relative position embedding in the encoder. + use_guided_attention_loss (`bool`, *optional*, defaults to `True`): + Whether to apply guided attention loss while training the TTS model. + guided_attention_loss_num_heads (`int`, *optional*, defaults to 2): + Number of attention heads the guided attention loss will be applied to. Use -1 to apply this loss to all + attention heads. + guided_attention_loss_sigma (`float`, *optional*, defaults to 0.4): + Standard deviation for guided attention loss. + guided_attention_loss_scale (`float`, *optional*, defaults to 10.0): + Scaling coefficient for guided attention loss (also known as lambda). use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). @@ -241,6 +250,10 @@ class SpeechT5Config(PretrainedConfig): max_speech_positions=4000, max_text_positions=450, encoder_max_relative_position=160, + use_guided_attention_loss=True, + guided_attention_loss_num_heads=2, + guided_attention_loss_sigma=0.4, + guided_attention_loss_scale=10.0, use_cache=True, is_encoder_decoder=True, **kwargs, @@ -311,6 +324,12 @@ class SpeechT5Config(PretrainedConfig): self.max_speech_positions = max_speech_positions self.max_text_positions = max_text_positions self.encoder_max_relative_position = encoder_max_relative_position + + self.use_guided_attention_loss = use_guided_attention_loss + self.guided_attention_loss_num_heads = guided_attention_loss_num_heads + self.guided_attention_loss_sigma = guided_attention_loss_sigma + self.guided_attention_loss_scale = guided_attention_loss_scale + self.use_cache = use_cache self.is_encoder_decoder = is_encoder_decoder diff --git a/src/transformers/models/speecht5/convert_speecht5_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/speecht5/convert_speecht5_original_pytorch_checkpoint_to_pytorch.py index de1e81fcbd7..20dea800d9d 100644 --- a/src/transformers/models/speecht5/convert_speecht5_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/speecht5/convert_speecht5_original_pytorch_checkpoint_to_pytorch.py @@ -351,12 +351,11 @@ def convert_speecht5_checkpoint( if vocab_path: tokenizer = SpeechT5Tokenizer(vocab_path, model_max_length=config.max_text_positions) - if task == "pretrain": - # Mask token behaves like a normal word, i.e. 
include the space before it - mask_token = AddedToken("", lstrip=True, rstrip=False) - tokenizer.mask_token = mask_token - tokenizer.add_special_tokens({"mask_token": mask_token}) - tokenizer.add_tokens([""]) + # Mask token behaves like a normal word, i.e. include the space before it + mask_token = AddedToken("", lstrip=True, rstrip=False) + tokenizer.mask_token = mask_token + tokenizer.add_special_tokens({"mask_token": mask_token}) + tokenizer.add_tokens([""]) feature_extractor = SpeechT5FeatureExtractor() processor = SpeechT5Processor(tokenizer=tokenizer, feature_extractor=feature_extractor) diff --git a/src/transformers/models/speecht5/feature_extraction_speecht5.py b/src/transformers/models/speecht5/feature_extraction_speecht5.py index 7b4f2af1736..8ceb48dc03c 100644 --- a/src/transformers/models/speecht5/feature_extraction_speecht5.py +++ b/src/transformers/models/speecht5/feature_extraction_speecht5.py @@ -14,12 +14,13 @@ # limitations under the License. """Feature extractor class for SpeechT5.""" -from typing import List, Optional, Union +import warnings +from typing import Any, Dict, List, Optional, Union import numpy as np import torch -import torchaudio +from ...audio_utils import get_mel_filter_banks from ...feature_extraction_sequence_utils import SequenceFeatureExtractor from ...feature_extraction_utils import BatchFeature from ...utils import PaddingStrategy, TensorType, logging @@ -60,7 +61,7 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor): win_function (`str`, *optional*, defaults to `"hann_window"`): Name for the window function used for windowing, must be accessible via `torch.{win_function}` frame_signal_scale (`float`, *optional*, defaults to 1.0): - Constant multiplied in creating the frames before applying DFT. + Constant multiplied in creating the frames before applying DFT. This argument is deprecated. fmin (`float`, *optional*, defaults to 80): Minimum mel frequency in Hz. fmax (`float`, *optional*, defaults to 7600): @@ -68,7 +69,7 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor): mel_floor (`float`, *optional*, defaults to 1e-10): Minimum value of mel frequency banks. reduction_factor (`int`, *optional*, defaults to 2): - Spectrogram length reduction factor. + Spectrogram length reduction factor. This argument is deprecated. return_attention_mask (`bool`, *optional*, defaults to `True`): Whether or not [`~SpeechT5FeatureExtractor.__call__`] should return `attention_mask`. 
""" @@ -109,10 +110,33 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor): self.sample_size = win_length * sampling_rate // 1000 self.sample_stride = hop_length * sampling_rate // 1000 - self.n_fft = 2 ** int(np.ceil(np.log2(self.sample_size))) self.n_freqs = (self.n_fft // 2) + 1 + window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True) + self.window = window.numpy().astype(np.float64) + + self.mel_filters = get_mel_filter_banks( + nb_frequency_bins=self.n_freqs, + nb_mel_filters=self.num_mel_bins, + frequency_min=self.fmin, + frequency_max=self.fmax, + sample_rate=self.sampling_rate, + norm="slaney", + mel_scale="slaney", + ) + + if frame_signal_scale != 1.0: + warnings.warn( + "The argument `frame_signal_scale` is deprecated and will be removed in version 4.30.0 of Transformers", + FutureWarning, + ) + if reduction_factor != 2.0: + warnings.warn( + "The argument `reduction_factor` is deprecated and will be removed in version 4.30.0 of Transformers", + FutureWarning, + ) + @staticmethod # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm def zero_mean_unit_var_norm( @@ -137,99 +161,45 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor): return normed_input_values @staticmethod - def _center_pad(one_waveform, n_fft, pad_mode): - padding = [(int(n_fft // 2), int(n_fft // 2))] - return np.pad(one_waveform, padding, mode=pad_mode) + def _stft(waveform: np.ndarray, fft_length: int, hop_length: int, window: np.ndarray) -> np.ndarray: + """ + Calculates the magnitude spectrogram over one waveform array. + """ + # center pad the waveform + padding = [(int(fft_length // 2), int(fft_length // 2))] + waveform = np.pad(waveform, padding, mode="reflect") + waveform_size = waveform.size - @staticmethod - # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._num_frames_calc - def _num_frames_calc(in_size, frame_size, frame_stride): - return int(1 + np.floor((in_size - frame_size) * 1 / frame_stride)) + # promote to float64, since np.fft uses float64 internally + waveform = waveform.astype(np.float64) - @staticmethod - # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._frame_signal - def _frame_signal(one_waveform, n_frames, frame_signal_scale, window_length, sample_stride): - scale = frame_signal_scale - frames = np.zeros(n_frames * window_length) - for frame_idx in range(n_frames): - start = frame_idx * window_length - end = (frame_idx + 1) * window_length - wave_start = frame_idx * sample_stride - wave_end = frame_idx * sample_stride + window_length - frames[start:end] = scale * one_waveform[wave_start:wave_end] + num_frames = int(1 + np.floor((waveform_size - fft_length) / hop_length)) + num_frequency_bins = (fft_length // 2) + 1 + spectrogram = np.empty((num_frames, num_frequency_bins)) - return frames + start = 0 + for frame_idx in range(num_frames): + frame = waveform[start : start + fft_length] * window + spectrogram[frame_idx] = np.abs(np.fft.rfft(frame)) + start += hop_length - @staticmethod - # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._windowing - def _windowing(frames, window_length, window): - if frames.size % window_length != 0: - raise ValueError( - f"`frames` is supposed to have length divisble by `window_length`, but is {frames.size} with" - f" window_length={window_length}." 
-            )
+        return spectrogram
 
-        shaped = frames.reshape(-1, window_length)
-        shaped = window * shaped
-        return shaped
-
-    @staticmethod
-    # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._dft
-    def _dft(frames, K, n_frames, n_samples, n_fft):
-        dft = np.zeros([n_frames, K])
-
-        for frame in range(n_frames):
-            begin = frame * n_samples
-
-            inwards_buffer = frames[begin : begin + n_samples]
-            inwards_buffer = np.pad(inwards_buffer, (0, n_fft - n_samples), "constant")
-            out = np.fft.rfft(inwards_buffer)
-
-            dft[frame] = np.abs(out[:K])
-
-        return dft
-
-    def _extract_fbank_features(
+    def _extract_mel_features(
         self,
         one_waveform: np.ndarray,
     ) -> np.ndarray:
         """
-        Extracts log-mel filterbank features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC
-        code and librosa.
+        Extracts log-mel filterbank features for one waveform array (unbatched).
         """
-        one_waveform = self._center_pad(one_waveform, self.n_fft, "reflect")
+        if self.n_fft != self.sample_size:
+            raise NotImplementedError(
+                f"Currently the STFT frame size must be a power of two, but got {self.sample_size} for a window length of {self.win_length} and sampling rate of {self.sampling_rate}. Ensure `win_length * sampling_rate // 1000` is a power of two."
+            )
 
-        n_frames = self._num_frames_calc(one_waveform.size, self.sample_size, self.sample_stride)
+        stft_out = self._stft(one_waveform, self.n_fft, self.sample_stride, self.window)
 
-        frames = self._frame_signal(
-            one_waveform, n_frames, self.frame_signal_scale, self.sample_size, self.sample_stride
-        )
-
-        window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True)
-        window = window.numpy()
-
-        frames = self._windowing(frames, self.sample_size, window)
-
-        dft_out = self._dft(frames.flatten(), self.n_freqs, n_frames, self.sample_size, self.n_fft)
-
-        fbanks = torchaudio.functional.melscale_fbanks(
-            n_freqs=self.n_freqs,
-            f_min=self.fmin,
-            f_max=self.fmax,
-            n_mels=self.num_mel_bins,
-            sample_rate=self.sampling_rate,
-            norm="slaney",
-            mel_scale="slaney",
-        )
-        fbanks = fbanks.numpy()
-
-        return np.log10(np.maximum(self.mel_floor, np.dot(dft_out, fbanks)))
-
-    def _reduce(self, inputs):
-        reduced = []
-        for i in range(len(inputs)):
-            reduced.append(inputs[i][self.reduction_factor - 1 :: self.reduction_factor])
-        return reduced
+        return np.log10(np.maximum(self.mel_floor, np.dot(stft_out, self.mel_filters)))
 
     def __call__(
         self,
@@ -341,7 +311,6 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
             return inputs_target
         else:
             inputs["labels"] = inputs_target["input_values"]
-            inputs["stop_labels"] = inputs_target["stop_labels"]
             decoder_attention_mask = inputs_target.get("attention_mask")
             if decoder_attention_mask is not None:
                 inputs["decoder_attention_mask"] = decoder_attention_mask
@@ -381,8 +350,7 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
 
         # convert into correct format for padding
         if is_target:
-            features = [self._extract_fbank_features(waveform) for waveform in speech]
-            fbank_sizes = [len(x) for x in features]
+            features = [self._extract_mel_features(waveform) for waveform in speech]
            encoded_inputs = BatchFeature({"input_values": features})
             self.feature_size = self.num_mel_bins
         else:
@@ -429,22 +397,18 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
                 padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value
             )
 
-        if is_target:
-            # make labels for stop prediction
-            stop_labels = []
-            for i, l in enumerate(fbank_sizes):
-                labels = 
np.zeros(len(padded_inputs["input_values"][i])) - labels[l - 1 :] = 1.0 - stop_labels.append(labels) - padded_inputs["stop_labels"] = stop_labels - - # thin out frames for reduction factor - if self.reduction_factor > 1: - padded_inputs["input_values"] = self._reduce(padded_inputs["input_values"]) - if attention_mask is not None: - padded_inputs["attention_mask"] = self._reduce(padded_inputs["attention_mask"]) - if return_tensors is not None: padded_inputs = padded_inputs.convert_to_tensors(return_tensors) return padded_inputs + + def to_dict(self) -> Dict[str, Any]: + output = super().to_dict() + + # Don't serialize these as they are derived from the other properties. + names = ["window", "mel_filters", "sample_size", "sample_stride", "n_fft", "n_freqs"] + for name in names: + if name in output: + del output[name] + + return output diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index c977414a526..819d8948cab 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -16,13 +16,14 @@ import math import random +import warnings from typing import List, Optional, Tuple, Union import numpy as np import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss from ...activations import ACT2FN from ...deepspeed import is_deepspeed_zero3_enabled @@ -72,12 +73,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -def shift_spectrograms_right(input_values: torch.Tensor): +def shift_spectrograms_right(input_values: torch.Tensor, reduction_factor: int = 1): """ - Shift input spectrograms one timestep to the right. + Shift input spectrograms one timestep to the right. Also applies the reduction factor to the sequence length. 
""" + # thin out frames for reduction factor + if reduction_factor > 1: + input_values = input_values[:, reduction_factor - 1 :: reduction_factor] + shifted_input_values = input_values.new_zeros(input_values.shape) shifted_input_values[:, 1:] = input_values[:, :-1].clone() + + # replace possible -100 values in labels by zeros + shifted_input_values.masked_fill_(shifted_input_values == -100.0, 0.0) + return shifted_input_values @@ -565,7 +574,7 @@ class SpeechT5SpeechEncoderPrenet(nn.Module): def forward( self, input_values: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, mask_time_indices: Optional[torch.FloatTensor] = None, ): extract_features = self.feature_encoder(input_values) @@ -840,7 +849,7 @@ class SpeechT5TextDecoderPrenet(nn.Module): def forward( self, input_ids: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, ): if input_ids is not None: @@ -1574,7 +1583,7 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel): def forward( self, hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, @@ -1589,7 +1598,7 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel): Args: hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`): Features extracted from the speech or text input by the decoder prenet. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, @@ -1783,7 +1792,7 @@ class SpeechT5DecoderWithSpeechPrenet(SpeechT5PreTrainedModel): def forward( self, input_values: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None, speaker_embeddings: Optional[torch.Tensor] = None, @@ -1837,7 +1846,7 @@ class SpeechT5DecoderWithTextPrenet(SpeechT5PreTrainedModel): def forward( self, input_values: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, @@ -1884,7 +1893,7 @@ class SpeechT5DecoderWithoutPrenet(SpeechT5PreTrainedModel): def forward( self, input_values: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, @@ -1911,6 +1920,126 @@ class SpeechT5DecoderWithoutPrenet(SpeechT5PreTrainedModel): return outputs +class SpeechT5GuidedMultiheadAttentionLoss(nn.Module): + """ + Guided attention loss from the paper [Efficiently Trainable Text-to-Speech System Based on Deep Convolutional + Networks with Guided Attention](https://arxiv.org/abs/1710.08969), adapted for multi-head attention. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__() + self.sigma = config.guided_attention_loss_sigma + self.scale = config.guided_attention_loss_scale + + def forward( + self, attentions: torch.FloatTensor, input_masks: torch.BoolTensor, output_masks: torch.BoolTensor + ) -> torch.Tensor: + """ + Compute the attention loss. + + Args: + attentions (`torch.FloatTensor` of shape `(batch_size, layers * heads, output_sequence_length, input_sequence_length)`): + Batch of multi-head attention weights + input_masks (`torch.BoolTensor` of shape `(batch_size, input_sequence_length)`): + Input attention mask as booleans. + output_masks (`torch.BoolTensor` of shape `(batch_size, output_sequence_length)`): + Target attention mask as booleans. 
+ + Returns: + `torch.Tensor` with the loss value + """ + guided_attn_masks = self._make_guided_attention_masks(input_masks, output_masks, attentions.device) + masks = output_masks.unsqueeze(-1) & input_masks.unsqueeze(-2) + masks = masks.to(attentions.device).unsqueeze(1) + + losses = guided_attn_masks * attentions + loss = torch.mean(losses.masked_select(masks)) + return self.scale * loss + + def _make_guided_attention_masks(self, input_masks, output_masks, device): + input_lengths = input_masks.sum(-1) + output_lengths = output_masks.sum(-1) + + guided_attn_masks = torch.zeros((len(input_masks), output_masks.shape[1], input_masks.shape[1]), device=device) + + for idx, (ilen, olen) in enumerate(zip(input_lengths, output_lengths)): + guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(ilen, olen, self.sigma, device) + + return guided_attn_masks.unsqueeze(1) + + @staticmethod + def _make_guided_attention_mask(input_length, output_length, sigma, device): + grid_y, grid_x = torch.meshgrid( + torch.arange(input_length, device=device), + torch.arange(output_length, device=device), + indexing="xy", + ) + grid_x = grid_x.float() / output_length + grid_y = grid_y.float() / input_length + return 1.0 - torch.exp(-((grid_y - grid_x) ** 2) / (2 * (sigma**2))) + + +class SpeechT5SpectrogramLoss(nn.Module): + """ + Loss computation used by SpeechT5ForTextToSpeech. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__() + self.use_guided_attention_loss = config.use_guided_attention_loss + self.guided_attention_loss_num_heads = config.guided_attention_loss_num_heads + self.reduction_factor = config.reduction_factor + + self.l1_criterion = L1Loss() + self.bce_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(5.0)) + + if self.use_guided_attention_loss: + self.attn_criterion = SpeechT5GuidedMultiheadAttentionLoss(config) + + def forward( + self, + attention_mask: torch.LongTensor, + outputs_before_postnet: torch.FloatTensor, + outputs_after_postnet: torch.FloatTensor, + logits: torch.FloatTensor, + labels: torch.FloatTensor, + cross_attentions: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + padding_mask = labels != -100.0 + + # mask out the padded portions + labels = labels.masked_select(padding_mask) + outputs_before_postnet = outputs_before_postnet.masked_select(padding_mask) + outputs_after_postnet = outputs_after_postnet.masked_select(padding_mask) + + # spectrogram loss + l1_loss = self.l1_criterion(outputs_after_postnet, labels) + self.l1_criterion(outputs_before_postnet, labels) + + # construct stop labels from the padding mask + masks = padding_mask[:, :, 0] + stop_labels = torch.cat([~masks * 1.0, torch.ones(masks.size(0), 1).to(masks.device)], dim=1) + stop_labels = stop_labels[:, 1:].masked_select(masks) + logits = logits.masked_select(masks) + + # stop token loss + bce_loss = self.bce_criterion(logits, stop_labels) + + # combined loss + loss = l1_loss + bce_loss + + # guided attention loss + if self.use_guided_attention_loss: + attn = torch.cat([x[:, : self.guided_attention_loss_num_heads] for x in cross_attentions], dim=1) + input_masks = attention_mask == 1 + output_masks = padding_mask[:, :, 0] + if self.reduction_factor > 1: + output_masks = output_masks[:, self.reduction_factor - 1 :: self.reduction_factor] + attn_loss = self.attn_criterion(attn, input_masks, output_masks) + loss += attn_loss + + return loss + + SPEECHT5_BASE_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -1981,13 +2110,13 @@ SPEECHT5_INPUTS_DOCSTRING = r""" and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + head_mask (`torch.FloatTensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. - decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + decoder_head_mask (`torch.FloatTensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, @@ -2088,27 +2217,27 @@ class SpeechT5Model(SpeechT5PreTrainedModel): @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - input_values: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.Tensor] = None, + input_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, decoder_input_values: Optional[torch.Tensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, decoder_head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, use_cache: Optional[bool] = None, - speaker_embeddings: Optional[torch.Tensor] = None, + speaker_embeddings: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" - input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`): Depending on which encoder is being used, the `input_values` are either: float values of the input raw speech waveform, or indices of input sequence tokens in the vocabulary, or hidden states. - decoder_input_values (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + decoder_input_values (`torch.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): Depending on which decoder is being used, the `decoder_input_values` are either: float values of log-mel filterbank features extracted from the raw speech waveform, or indices of decoder input sequence tokens in the vocabulary, or hidden states. 
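Note on the guided attention loss introduced above: the mask built by `SpeechT5GuidedMultiheadAttentionLoss._make_guided_attention_mask` is the soft diagonal penalty from Tachibana et al. (2017). Entries near the output/input diagonal are close to 0, so only attention mass that strays off the diagonal is penalized. A minimal standalone sketch of that construction follows; the sequence lengths are made up for illustration, and only the `sigma` value matches the new `guided_attention_loss_sigma` default.

```python
import torch

input_length, output_length, sigma = 4, 6, 0.4  # toy lengths; sigma = config default

# Same construction as `_make_guided_attention_mask`:
# mask[t, n] = 1 - exp(-((n / N - t / T) ** 2) / (2 * sigma ** 2))
grid_y, grid_x = torch.meshgrid(
    torch.arange(input_length),
    torch.arange(output_length),
    indexing="xy",
)
mask = 1.0 - torch.exp(
    -((grid_y.float() / input_length - grid_x.float() / output_length) ** 2) / (2 * sigma**2)
)
print(mask.shape)  # torch.Size([6, 4]) -> (output_sequence_length, input_sequence_length)
```

Multiplying this mask elementwise with the cross-attention weights, averaging over the unpadded positions, and scaling by `guided_attention_loss_scale` gives the extra loss term added to the L1 and BCE losses.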
@@ -2246,10 +2375,10 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel): @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - input_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - decoder_input_ids: Optional[torch.Tensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, + input_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, decoder_head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, @@ -2259,7 +2388,7 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - labels: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, ) -> Union[Tuple, Seq2SeqLMOutput]: r""" input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): @@ -2414,7 +2543,8 @@ def _generate_speech( minlenratio: float = 0.0, maxlenratio: float = 20.0, vocoder: Optional[nn.Module] = None, -) -> torch.FloatTensor: + output_cross_attentions: bool = False, +) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]: encoder_attention_mask = torch.ones_like(input_values) encoder_out = model.speecht5.encoder( @@ -2438,6 +2568,7 @@ def _generate_speech( output_sequence = encoder_last_hidden_state.new_zeros(1, 1, model.config.num_mel_bins) spectrogram = [] + cross_attentions = [] past_key_values = None idx = 0 @@ -2455,9 +2586,13 @@ def _generate_speech( encoder_attention_mask=encoder_attention_mask, past_key_values=past_key_values, use_cache=True, + output_attentions=output_cross_attentions, return_dict=True, ) + if output_cross_attentions: + cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0)) + last_decoder_output = decoder_out.last_hidden_state[0, -1] past_key_values = decoder_out.past_key_values @@ -2480,9 +2615,15 @@ def _generate_speech( break if vocoder is not None: - return vocoder(spectrogram) + outputs = vocoder(spectrogram) else: - return spectrogram + outputs = spectrogram + + if output_cross_attentions: + cross_attentions = torch.cat(cross_attentions, dim=2) + outputs = (outputs, cross_attentions) + + return outputs @add_start_docstrings( @@ -2525,10 +2666,10 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): @replace_return_docstrings(output_type=Seq2SeqSpectrogramOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - decoder_input_values: Optional[torch.Tensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_values: Optional[torch.FloatTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, decoder_head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, @@ -2538,8 +2679,8 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - speaker_embeddings: Optional[torch.Tensor] = 
None, - labels: Optional[torch.Tensor] = None, + speaker_embeddings: Optional[torch.FloatTensor] = None, + labels: Optional[torch.FloatTensor] = None, stop_labels: Optional[torch.Tensor] = None, ) -> Union[Tuple, Seq2SeqSpectrogramOutput]: r""" @@ -2559,13 +2700,9 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): Tensor containing the speaker embeddings. labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*): - Float values of target mel spectrogram. Spectrograms can be obtained using [`SpeechT5Processor`]. See - [`SpeechT5Processor.__call__`] for details. - stop_labels (`torch.FloatTensor` of shape `(batch_size, unreduced_sequence_length)`, *optional*): - Labels for computing the stop token loss. Values are 0.0 until the end of the sequence, after which they - become 1.0. The sequence length of this tensor is `config.reduction_factor` times larger than the length of - the target mel spectrogram. Labels can be obtained using [`SpeechT5Processor`]. See - [`SpeechT5Processor.__call__`] for details. + Float values of target mel spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss + computation. Spectrograms can be obtained using [`SpeechT5Processor`]. See [`SpeechT5Processor.__call__`] + for details. Returns: @@ -2592,9 +2729,17 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if stop_labels is not None: + warnings.warn( + "The argument `stop_labels` is deprecated and will be removed in version 4.30.0 of Transformers", + FutureWarning, + ) + if labels is not None: if decoder_input_values is None: - decoder_input_values = shift_spectrograms_right(labels) + decoder_input_values = shift_spectrograms_right(labels, self.config.reduction_factor) + if self.config.use_guided_attention_loss: + output_attentions = True outputs = self.speecht5( input_values=input_ids, @@ -2613,17 +2758,27 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): return_dict=True, ) - _, spectrogram, logits = self.speech_decoder_postnet(outputs[0]) + outputs_before_postnet, outputs_after_postnet, logits = self.speech_decoder_postnet(outputs[0]) loss = None + if labels is not None: + criterion = SpeechT5SpectrogramLoss(self.config) + loss = criterion( + attention_mask, + outputs_before_postnet, + outputs_after_postnet, + logits, + labels, + outputs.cross_attentions, + ) if not return_dict: - output = (spectrogram,) + outputs[1:] + output = (outputs_after_postnet,) + outputs[1:] return ((loss,) + output) if loss is not None else output return Seq2SeqSpectrogramOutput( loss=loss, - spectrogram=spectrogram, + spectrogram=outputs_after_postnet, past_key_values=outputs.past_key_values, decoder_hidden_states=outputs.decoder_hidden_states, decoder_attentions=outputs.decoder_attentions, @@ -2642,7 +2797,8 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): minlenratio: float = 0.0, maxlenratio: float = 20.0, vocoder: Optional[nn.Module] = None, - ) -> torch.FloatTensor: + output_cross_attentions: bool = False, + ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]: r""" Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a speech waveform using a vocoder. 
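For reviewers who want to try the new `output_cross_attentions` flag end to end, here is a minimal usage sketch. The checkpoint name and the zero speaker embedding are illustrative assumptions, not part of this diff (`SpeechT5ForSpeechToSpeech.generate_speech` defaults to a zero vector of the same size).

```python
import torch
from transformers import SpeechT5ForTextToSpeech, SpeechT5Processor

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")  # assumed checkpoint
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")

inputs = processor(text="hello world", return_tensors="pt")
speaker_embeddings = torch.zeros(1, 512)  # placeholder; real use would pass x-vectors

# Without a vocoder, `generate_speech` now returns (spectrogram, cross_attentions).
spectrogram, cross_attentions = model.generate_speech(
    inputs["input_ids"], speaker_embeddings, output_cross_attentions=True
)
print(cross_attentions.shape)
# (config.decoder_layers, config.decoder_attention_heads,
#  output_sequence_length, input_sequence_length)
```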
@@ -2666,10 +2822,18 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): vocoder (`nn.Module`, *optional*, defaults to `None`): The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel spectrogram. + output_cross_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of the decoder's cross-attention layers. Returns: - `torch.FloatTensor`: Tensor of shape `(output_sequence_length, config.num_mel_bins)` containing the - predicted mel spectrogram, or a tensor with shape `(num_frames,)` containing the speech waveform. + `tuple(torch.FloatTensor)` comprising various elements depending on the inputs: + - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram. + - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(num_frames,)` -- The predicted speech waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) `torch.FloatTensor` + of shape `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length, + input_sequence_length)` -- The outputs of the decoder's cross-attention layers. """ return _generate_speech( self, @@ -2679,6 +2843,7 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): minlenratio, maxlenratio, vocoder, + output_cross_attentions, ) @@ -2723,10 +2888,10 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel): @replace_return_docstrings(output_type=Seq2SeqSpectrogramOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - input_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - decoder_input_values: Optional[torch.Tensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, + input_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_values: Optional[torch.FloatTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, decoder_head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, @@ -2736,8 +2901,8 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - speaker_embeddings: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, + speaker_embeddings: Optional[torch.FloatTensor] = None, + labels: Optional[torch.FloatTensor] = None, stop_labels: Optional[torch.Tensor] = None, ) -> Union[Tuple, Seq2SeqSpectrogramOutput]: r""" @@ -2757,11 +2922,6 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel): labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*): Float values of target mel spectrogram. Spectrograms can be obtained using [`SpeechT5Processor`]. See [`SpeechT5Processor.__call__`] for details. - stop_labels (`torch.FloatTensor` of shape `(batch_size, unreduced_sequence_length)`, *optional*): - Labels for computing the stop token loss. Values are 0.0 until the end of the sequence, after which they - become 1.0. The sequence length of this tensor is `config.reduction_factor` times larger than the length of - the target mel spectrogram. Labels can be obtained using [`SpeechT5Processor`]. 
See - [`SpeechT5Processor.__call__`] for details. Returns: @@ -2797,9 +2957,15 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel): """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if stop_labels is not None: + warnings.warn( + "The argument `stop_labels` is deprecated and will be removed in version 4.30.0 of Transformers", + FutureWarning, + ) + if labels is not None: if decoder_input_values is None: - decoder_input_values = shift_spectrograms_right(labels) + decoder_input_values = shift_spectrograms_right(labels, self.config.reduction_factor) outputs = self.speecht5( input_values=input_values, @@ -2847,6 +3013,7 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel): minlenratio: float = 0.0, maxlenratio: float = 20.0, vocoder: Optional[nn.Module] = None, + output_cross_attentions: bool = False, ) -> torch.FloatTensor: r""" Converts a raw speech waveform into a sequence of mel spectrograms, which are subsequently turned back into a @@ -2871,10 +3038,18 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel): vocoder (`nn.Module`, *optional*, defaults to `None`): The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel spectrogram. + output_cross_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of the decoder's cross-attention layers. Returns: - `torch.FloatTensor`: Tensor of shape `(output_sequence_length, config.num_mel_bins)` containing the - predicted mel spectrogram, or a tensor with shape `(num_frames,)` containing the speech waveform. + `tuple(torch.FloatTensor)` comprising various elements depending on the inputs: + - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram. + - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(num_frames,)` -- The predicted speech waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) `torch.FloatTensor` + of shape `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length, + input_sequence_length)` -- The outputs of the decoder's cross-attention layers. """ if speaker_embeddings is None: speaker_embeddings = torch.zeros((1, 512), device=input_values.device) @@ -2887,6 +3062,7 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel): minlenratio, maxlenratio, vocoder, + output_cross_attentions, ) @@ -3034,7 +3210,7 @@ class SpeechT5HifiGan(PreTrainedModel): layer.remove_weight_norm() nn.utils.remove_weight_norm(self.conv_post) - def forward(self, spectrogram): + def forward(self, spectrogram: torch.FloatTensor) -> torch.FloatTensor: r""" Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch of speech waveforms. 
Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech diff --git a/src/transformers/models/speecht5/processing_speecht5.py b/src/transformers/models/speecht5/processing_speecht5.py index 3cb66498337..27353b4702b 100644 --- a/src/transformers/models/speecht5/processing_speecht5.py +++ b/src/transformers/models/speecht5/processing_speecht5.py @@ -87,59 +87,85 @@ class SpeechT5Processor(ProcessorMixin): inputs = None if audio_target is not None: - audio_target_features = self.feature_extractor( - audio_target=audio_target, *args, sampling_rate=sampling_rate, **kwargs - ) - if inputs is None: - return audio_target_features - else: - inputs["labels"] = audio_target_features["input_values"] - inputs["stop_labels"] = audio_target_features["stop_labels"] - decoder_attention_mask = audio_target_features.get("attention_mask") - if decoder_attention_mask is not None: - inputs["decoder_attention_mask"] = decoder_attention_mask + targets = self.feature_extractor(audio_target=audio_target, *args, sampling_rate=sampling_rate, **kwargs) + labels = targets["input_values"] + elif text_target is not None: + targets = self.tokenizer(text_target, **kwargs) + labels = targets["input_ids"] + else: + targets = None - if text_target is not None: - encodings_target = self.tokenizer(text_target, **kwargs) - if inputs is None: - return encodings_target - else: - inputs["labels"] = encodings_target["input_ids"] - decoder_attention_mask = encodings_target.get("attention_mask") - if decoder_attention_mask is not None: - inputs["decoder_attention_mask"] = decoder_attention_mask + if inputs is None: + return targets + + if targets is not None: + inputs["labels"] = labels + + decoder_attention_mask = targets.get("attention_mask") + if decoder_attention_mask is not None: + inputs["decoder_attention_mask"] = decoder_attention_mask return inputs def pad(self, *args, **kwargs): """ - This method forwards all its arguments to SpeechT5FeatureExtractor's [`~SpeechT5FeatureExtractor.pad`] and - returns its output. + Collates the audio and text inputs, as well as their targets, into a padded batch. - You can process your labels by using the argument `text` (either in the same call as your audio inputs, or in a - separate call). This forwards its arguments to SpeechT5Tokenizer's [`~SpeechT5Tokenizer.pad`]. + Audio inputs are padded by SpeechT5FeatureExtractor's [`~SpeechT5FeatureExtractor.pad`]. Text inputs are padded + by SpeechT5Tokenizer's [`~SpeechT5Tokenizer.pad`]. + + Valid input combinations are: + + - `input_ids` only + - `input_values` only + - `labels` only, either log-mel spectrograms or text tokens + - `input_ids` and log-mel spectrogram `labels` + - `input_values` and text `labels` Please refer to the docstring of the above two methods for more information. """ - input_features = kwargs.pop("input_features", None) + input_values = kwargs.pop("input_values", None) + input_ids = kwargs.pop("input_ids", None) labels = kwargs.pop("labels", None) - if len(args) > 0: - input_features = args[0] - args = args[1:] + if input_values is not None and input_ids is not None: + raise ValueError("Cannot process both `input_values` and `input_ids` inputs.") + if input_values is None and input_ids is None and labels is None: + raise ValueError( + "You need to specify either an `input_values`, `input_ids`, or `labels` input to be padded." 
+ ) - if input_features is not None: - input_features = self.feature_extractor.pad(input_features, *args, **kwargs) - if labels is not None: - labels = self.tokenizer.pad(labels, **kwargs) - - if labels is None: - return input_features - elif input_features is None: - return labels + if input_values is not None: + inputs = self.feature_extractor.pad(input_values, *args, **kwargs) + elif input_ids is not None: + inputs = self.tokenizer.pad(input_ids, **kwargs) else: - input_features["labels"] = labels["input_ids"] - return input_features + inputs = None + + if labels is not None: + if "input_ids" in labels or (isinstance(labels, list) and "input_ids" in labels[0]): + targets = self.tokenizer.pad(labels, **kwargs) + labels = targets["input_ids"] + else: + feature_size_hack = self.feature_extractor.feature_size + self.feature_extractor.feature_size = self.feature_extractor.num_mel_bins + targets = self.feature_extractor.pad(labels, *args, **kwargs) + self.feature_extractor.feature_size = feature_size_hack + labels = targets["input_values"] + else: + targets = None + + if inputs is None: + return targets + + if targets is not None: + inputs["labels"] = labels + + decoder_attention_mask = targets.get("attention_mask") + if decoder_attention_mask is not None: + inputs["decoder_attention_mask"] = decoder_attention_mask + + return inputs def batch_decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/speecht5/tokenization_speecht5.py b/src/transformers/models/speecht5/tokenization_speecht5.py index a0b933f3056..9f93be5ecdc 100644 --- a/src/transformers/models/speecht5/tokenization_speecht5.py +++ b/src/transformers/models/speecht5/tokenization_speecht5.py @@ -167,6 +167,26 @@ class SpeechT5Tokenizer(PreTrainedTokenizer): out_string += self.sp_model.decode(current_sub_tokens) return out_string.strip() + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]: + """Build model inputs from a sequence by appending eos_token_id.""" + if token_ids_1 is None: + return token_ids_0 + [self.eos_token_id] + # We don't expect to process pairs, but leave the pair logic for API consistency + return token_ids_0 + token_ids_1 + [self.eos_token_id] + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + suffix_ones = [1] + if token_ids_1 is None: + return ([0] * len(token_ids_0)) + suffix_ones + return ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: if not os.path.isdir(save_directory): logger.error(f"Vocabulary path ({save_directory}) should be a directory") diff --git a/tests/models/speecht5/test_feature_extraction_speecht5.py b/tests/models/speecht5/test_feature_extraction_speecht5.py index 390b769b8db..d19c71dd56f 100644 --- a/tests/models/speecht5/test_feature_extraction_speecht5.py +++ b/tests/models/speecht5/test_feature_extraction_speecht5.py @@ -21,7 +21,7 @@ import unittest import numpy as np from transformers import BatchFeature, is_speech_available -from transformers.testing_utils import require_torch, require_torchaudio +from transformers.testing_utils import require_torch from transformers.utils.import_utils import is_torch_available from 
...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin @@ -67,11 +67,9 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase): hop_length=16, win_length=64, win_function="hann_window", - frame_signal_scale=1.0, fmin=80, fmax=7600, mel_floor=1e-10, - reduction_factor=2, return_attention_mask=True, ): self.parent = parent @@ -87,11 +85,9 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase): self.hop_length = hop_length self.win_length = win_length self.win_function = win_function - self.frame_signal_scale = frame_signal_scale self.fmin = fmin self.fmax = fmax self.mel_floor = mel_floor - self.reduction_factor = reduction_factor self.return_attention_mask = return_attention_mask def prepare_feat_extract_dict(self): @@ -104,11 +100,9 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase): "hop_length": self.hop_length, "win_length": self.win_length, "win_function": self.win_function, - "frame_signal_scale": self.frame_signal_scale, "fmin": self.fmin, "fmax": self.fmax, "mel_floor": self.mel_floor, - "reduction_factor": self.reduction_factor, "return_attention_mask": self.return_attention_mask, } @@ -147,7 +141,6 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase): @require_torch -@require_torchaudio class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase): feature_extraction_class = SpeechT5FeatureExtractor if is_speech_available() else None @@ -407,10 +400,10 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest def test_integration_target(self): # fmt: off EXPECTED_INPUT_VALUES = torch.tensor( - [-2.7713, -2.8896, -3.2619, -3.0843, -2.9919, -3.0084, -3.2796, -3.3169, - -3.2397, -3.2053, -2.9151, -2.7921, -2.9403, -2.7411, -3.0654, -2.8314, - -3.0026, -2.9797, -3.1314, -2.9939, -2.6748, -2.7725, -2.8563, -2.9462, - -3.2623, -3.3044, -3.1318, -3.2672, -3.4030, -3.1988] + [-2.6870, -3.0104, -3.1356, -3.5352, -3.0044, -3.0353, -3.4719, -3.6777, + -3.1520, -2.9435, -2.6553, -2.8795, -2.9944, -2.5921, -3.0279, -3.0386, + -3.0864, -3.1291, -3.2353, -2.7444, -2.6831, -2.7287, -3.1761, -3.1571, + -3.2726, -3.0582, -3.1007, -3.4533, -3.4695, -3.0998] ) # fmt: on diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py index 028a4c50df3..8fbbee84f22 100644 --- a/tests/models/speecht5/test_modeling_speecht5.py +++ b/tests/models/speecht5/test_modeling_speecht5.py @@ -25,10 +25,10 @@ from transformers.testing_utils import ( require_sentencepiece, require_tokenizers, require_torch, - require_torchaudio, slow, torch_device, ) +from transformers.trainer_utils import set_seed from transformers.utils import cached_property from ...test_configuration_common import ConfigTester @@ -716,7 +716,6 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase): @require_torch -@require_torchaudio @require_sentencepiece @require_tokenizers @slow @@ -991,7 +990,6 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase): @require_torch -@require_torchaudio @require_sentencepiece @require_tokenizers @slow @@ -1005,11 +1003,13 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase): model.to(torch_device) processor = self.default_processor + set_seed(555) # make deterministic + input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel" input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device) generated_speech = 
model.generate_speech(input_ids) - self.assertEqual(generated_speech.shape, (1800, model.config.num_mel_bins)) + self.assertEqual(generated_speech.shape, (1820, model.config.num_mel_bins)) @require_torch @@ -1406,7 +1406,6 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase): @require_torch -@require_torchaudio @require_sentencepiece @require_tokenizers @slow diff --git a/tests/models/speecht5/test_processor_speecht5.py b/tests/models/speecht5/test_processor_speecht5.py index d3f28738cb7..97d3842f105 100644 --- a/tests/models/speecht5/test_processor_speecht5.py +++ b/tests/models/speecht5/test_processor_speecht5.py @@ -21,7 +21,7 @@ import unittest from transformers import is_speech_available, is_torch_available from transformers.models.speecht5 import SpeechT5Tokenizer -from transformers.testing_utils import get_tests_dir, require_torch, require_torchaudio +from transformers.testing_utils import get_tests_dir, require_torch from transformers.utils import FEATURE_EXTRACTOR_NAME @@ -35,7 +35,6 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_bpe_char.model") @require_torch -@require_torchaudio class SpeechT5ProcessorTest(unittest.TestCase): def setUp(self): self.tmpdirname = tempfile.mkdtemp() @@ -52,7 +51,6 @@ class SpeechT5ProcessorTest(unittest.TestCase): "hop_length": 16, "win_length": 64, "win_function": "hann_window", - "frame_signal_scale": 1.0, "fmin": 80, "fmax": 7600, "mel_floor": 1e-10, diff --git a/tests/models/speecht5/test_tokenization_speecht5.py b/tests/models/speecht5/test_tokenization_speecht5.py index ec8dca122ff..8e23e233bae 100644 --- a/tests/models/speecht5/test_tokenization_speecht5.py +++ b/tests/models/speecht5/test_tokenization_speecht5.py @@ -19,6 +19,7 @@ import unittest from transformers import SPIECE_UNDERLINE from transformers.models.speecht5 import SpeechT5Tokenizer from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow +from transformers.tokenization_utils import AddedToken from ...test_tokenization_common import TokenizerTesterMixin @@ -38,6 +39,12 @@ class SpeechT5TokenizerTest(TokenizerTesterMixin, unittest.TestCase): # We have a SentencePiece fixture for testing tokenizer = SpeechT5Tokenizer(SAMPLE_VOCAB) + + mask_token = AddedToken("", lstrip=True, rstrip=False) + tokenizer.mask_token = mask_token + tokenizer.add_special_tokens({"mask_token": mask_token}) + tokenizer.add_tokens([""]) + tokenizer.save_pretrained(self.tmpdirname) def get_input_output_texts(self, tokenizer): @@ -64,8 +71,10 @@ class SpeechT5TokenizerTest(TokenizerTesterMixin, unittest.TestCase): self.assertEqual(vocab_keys[0], "") self.assertEqual(vocab_keys[1], "") - self.assertEqual(vocab_keys[-2], "œ") - self.assertEqual(len(vocab_keys), 79) + self.assertEqual(vocab_keys[-4], "œ") + self.assertEqual(vocab_keys[-2], "") + self.assertEqual(vocab_keys[-1], "") + self.assertEqual(len(vocab_keys), 81) def test_vocab_size(self): self.assertEqual(self.get_tokenizer().vocab_size, 79) @@ -175,7 +184,18 @@ class SpeechT5TokenizerTest(TokenizerTesterMixin, unittest.TestCase): ] # fmt: off - expected_encoding = {'input_ids': [[4, 32, 13, 7, 9, 12, 19, 8, 13, 18, 5, 13, 12, 4, 64, 19, 8, 13, 18, 5, 13, 15, 22, 4, 28, 9, 8, 20, 9, 4, 7, 12, 4, 24, 22, 6, 8, 13, 17, 11, 39, 6, 13, 7, 9, 12, 19, 8, 13, 18, 5, 13, 12, 4, 7, 9, 14, 4, 24, 22, 6, 8, 13, 17, 11, 39, 24, 13, 5, 6, 13, 7, 10, 9, 5, 14, 39, 25, 5, 13, 6, 63, 4, 24, 13, 8, 27, 10, 14, 5, 12, 4, 21, 5, 9, 5, 13, 7, 15, 39, 24, 16, 13, 24, 8, 12, 5, 4, 7, 13, 17, 
11, 10, 6, 5, 17, 6, 16, 13, 5, 12, 4, 64, 40, 47, 54, 32, 23, 4, 53, 49, 32, 23, 4, 54, 8, 40, 47, 54, 32, 7, 23, 4, 69, 52, 43, 23, 4, 51, 10, 12, 6, 10, 15, 40, 5, 13, 6, 23, 4, 69, 52, 48, 5, 6, 26, 26, 26, 63, 4, 19, 8, 13, 4, 48, 7, 6, 16, 13, 7, 15, 4, 52, 7, 9, 21, 16, 7, 21, 5, 4, 61, 9, 14, 5, 13, 12, 6, 7, 9, 14, 10, 9, 21, 4, 64, 48, 52, 61, 63, 4, 7, 9, 14, 4, 48, 7, 6, 16, 13, 7, 15, 4, 52, 7, 9, 21, 16, 7, 21, 5, 4, 53, 5, 9, 5, 13, 7, 6, 10, 8, 9, 4, 64, 48, 52, 53, 63, 4, 20, 10, 6, 11, 4, 8, 27, 5, 13, 4, 6, 11, 10, 13, 6, 22, 39, 6, 20, 8, 4, 24, 13, 5, 6, 13, 7, 10, 9, 5, 14, 4, 18, 8, 14, 5, 15, 12, 4, 10, 9, 4, 8, 9, 5, 4, 11, 16, 9, 14, 13, 5, 14, 4, 24, 15, 16, 12, 4, 15, 7, 9, 21, 16, 7, 21, 5, 12, 4, 7, 9, 14, 4, 14, 5, 5, 24, 4, 10, 9, 6, 5, 13, 8, 24, 5, 13, 7, 25, 10, 15, 10, 6, 22, 4, 25, 5, 6, 20, 5, 5, 9, 4, 58, 7, 37, 23, 4, 49, 22, 32, 8, 13, 17, 11, 4, 7, 9, 14, 4, 32, 5, 9, 12, 8, 13, 55, 15, 8, 20, 26], [4, 40, 47, 54, 32, 4, 10, 12, 4, 14, 5, 12, 10, 21, 9, 5, 14, 4, 6, 8, 4, 24, 13, 5, 39, 6, 13, 7, 10, 9, 4, 14, 5, 5, 24, 4, 25, 10, 14, 10, 13, 5, 17, 6, 10, 8, 9, 7, 15, 4, 13, 5, 24, 13, 5, 12, 5, 9, 6, 7, 6, 10, 8, 9, 12, 4, 19, 13, 8, 18, 4, 16, 9, 15, 7, 25, 5, 15, 5, 14, 4, 6, 5, 37, 6, 4, 25, 22, 4, 46, 8, 10, 9, 6, 15, 22, 4, 17, 8, 9, 14, 10, 6, 10, 8, 9, 10, 9, 21, 4, 8, 9, 4, 25, 8, 6, 11, 4, 15, 5, 19, 6, 4, 7, 9, 14, 4, 13, 10, 21, 11, 6, 4, 17, 8, 9, 6, 5, 37, 6, 4, 10, 9, 4, 7, 15, 15, 4, 15, 7, 22, 5, 13, 12, 26, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [4, 32, 11, 5, 4, 45, 16, 10, 17, 28, 4, 25, 13, 8, 20, 9, 4, 19, 8, 37, 4, 46, 16, 18, 24, 12, 4, 8, 27, 5, 13, 4, 6, 11, 5, 4, 15, 7, 57, 22, 4, 14, 8, 21, 26, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}
+        expected_encoding = {
+            'input_ids': [
+                [4, 32, 13, 7, 9, 12, 19, 8, 13, 18, 5, 13, 12, 4, 64, 19, 8, 13, 18, 5, 13, 15, 22, 4, 28, 9, 8, 20, 9, 4, 7, 12, 4, 24, 22, 6, 8, 13, 17, 11, 39, 6, 13, 7, 9, 12, 19, 8, 13, 18, 5, 13, 12, 4, 7, 9, 14, 4, 24, 22, 6, 8, 13, 17, 11, 39, 24, 13, 5, 6, 13, 7, 10, 9, 5, 14, 39, 25, 5, 13, 6, 63, 4, 24, 13, 8, 27, 10, 14, 5, 12, 4, 21, 5, 9, 5, 13, 7, 15, 39, 24, 16, 13, 24, 8, 12, 5, 4, 7, 13, 17, 11, 10, 6, 5, 17, 6, 16, 13, 5, 12, 4, 64, 40, 47, 54, 32, 23, 4, 53, 49, 32, 23, 4, 54, 8, 40, 47, 54, 32, 7, 23, 4, 69, 52, 43, 23, 4, 51, 10, 12, 6, 10, 15, 40, 5, 13, 6, 23, 4, 69, 52, 48, 5, 6, 26, 26, 26, 63, 4, 19, 8, 13, 4, 48, 7, 6, 16, 13, 7, 15, 4, 52, 7, 9, 21, 16, 7, 21, 5, 4, 61, 9, 14, 5, 13, 12, 6, 7, 9, 14, 10, 9, 21, 4, 64, 48, 52, 61, 63, 4, 7, 9, 14, 4, 48, 7, 6, 16, 13, 7, 15, 4, 52, 7, 9, 21, 16, 7, 21, 5, 4, 53, 5, 9, 5, 13, 7, 6, 10, 8, 9, 4, 64, 48, 52, 53, 63, 4, 20, 10, 6, 11, 4, 8, 27, 5, 13, 4, 6, 11, 10, 13, 6, 22, 39, 6, 20, 8, 4, 24, 13, 5, 6, 13, 7, 10, 9, 5, 14, 4, 18, 8, 14, 5, 15, 12, 4, 10, 9, 4, 8, 9, 5, 4, 11, 16, 9, 14, 13, 5, 14, 4, 24, 15, 16, 12, 4, 15, 7, 9, 21, 16, 7, 21, 5, 12, 4, 7, 9, 14, 4, 14, 5, 5, 24, 4, 10, 9, 6, 5, 13, 8, 24, 5, 13, 7, 25, 10, 15, 10, 6, 22, 4, 25, 5, 6, 20, 5, 5, 9, 4, 58, 7, 37, 23, 4, 49, 22, 32, 8, 13, 17, 11, 4, 7, 9, 14, 4, 32, 5, 9, 12, 8, 13, 55, 15, 8, 20, 26, 2],
+                [4, 40, 47, 54, 32, 4, 10, 12, 4, 14, 5, 12, 10, 21, 9, 5, 14, 4, 6, 8, 4, 24, 13, 5, 39, 6, 13, 7, 10, 9, 4, 14, 5, 5, 24, 4, 25, 10, 14, 10, 13, 5, 17, 6, 10, 8, 9, 7, 15, 4, 13, 5, 24, 13, 5, 12, 5, 9, 6, 7, 6, 10, 8, 9, 12, 4, 19, 13, 8, 18, 4, 16, 9, 15, 7, 25, 5, 15, 5, 14, 4, 6, 5, 37, 6, 4, 25, 22, 4, 46, 8, 10, 9, 6, 15, 22, 4, 17, 8, 9, 14, 10, 6, 10, 8, 9, 10, 9, 21, 4, 8, 9, 4, 25, 8, 6, 11, 4, 15, 5, 19, 6, 4, 7, 9, 14, 4, 13, 10, 21, 11, 6, 4, 17, 8, 9, 6, 5, 37, 6, 4, 10, 9, 4, 7, 15, 15, 4, 15, 7, 22, 5, 13, 12, 26, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+                [4, 32, 11, 5, 4, 45, 16, 10, 17, 28, 4, 25, 13, 8, 20, 9, 4, 19, 8, 37, 4, 46, 16, 18, 24, 12, 4, 8, 27, 5, 13, 4, 6, 11, 5, 4, 15, 7, 57, 22, 4, 14, 8, 21, 26, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+            ],
+            'attention_mask': [
+                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+            ]
+        }
         # fmt: on
         self.tokenizer_integration_test_util(