resolve conflicts for seamless_m4t

Eustache Le Bihan 2025-07-01 17:56:54 +02:00
parent 7eb4de32f9
commit e0f56e4716
7 changed files with 975 additions and 1018 deletions

View File

@ -17,7 +17,7 @@
import copy
import math
import warnings
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Optional, Union
import torch
import torch.utils.checkpoint
@ -33,6 +33,7 @@ from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask_for_sdpa,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@ -198,7 +199,7 @@ class BartAttention(nn.Module):
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
@ -270,7 +271,7 @@ class BartAttention(nn.Module):
return attn_output, attn_weights, past_key_value
class BartEncoderLayer(nn.Module):
class BartEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: BartConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
@ -296,7 +297,7 @@ class BartEncoderLayer(nn.Module):
attention_mask: torch.FloatTensor,
layer_head_mask: torch.FloatTensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@ -341,7 +342,7 @@ class BartEncoderLayer(nn.Module):
return outputs
class BartDecoderLayer(nn.Module):
class BartDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: BartConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
@ -385,7 +386,7 @@ class BartDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
cache_position: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@ -782,7 +783,7 @@ class BartEncoder(BartPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
) -> Union[tuple, BaseModelOutput]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@ -865,7 +866,7 @@ class BartEncoder(BartPreTrainedModel):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
to_drop = False
if self.training:
dropout_probability = torch.rand([])
@ -875,21 +876,12 @@ class BartEncoder(BartPreTrainedModel):
if to_drop:
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
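The collapsed loop above works because `BartEncoderLayer` now subclasses `GradientCheckpointingLayer` (imported from `...modeling_layers`), which moves the checkpointing decision into the layer itself. A minimal sketch of the idea behind such a base class, not the actual transformers implementation:

from functools import partial
from torch import nn
from torch.utils.checkpoint import checkpoint

class GradientCheckpointingLayerSketch(nn.Module):
    # Assumed sketch: reroute __call__ through torch.utils.checkpoint when enabled and training.
    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # Keyword arguments are bound up front so only positional tensors flow through the
            # checkpoint call; one plausible reason the decoder loops below now pass
            # encoder_hidden_states positionally.
            return checkpoint(partial(super().__call__, **kwargs), *args, use_reentrant=False)
        return super().__call__(*args, **kwargs)

With the branch handled in `__call__`, every call site can shrink to the single `encoder_layer(...)` form shown in this hunk.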
@ -956,14 +948,14 @@ class BartDecoder(BartPreTrainedModel):
encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@ -1129,7 +1121,7 @@ class BartDecoder(BartPreTrainedModel):
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.training:
@ -1137,35 +1129,18 @@ class BartDecoder(BartPreTrainedModel):
if dropout_probability < self.layerdrop:
continue
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
@ -1253,8 +1228,8 @@ class BartModel(BartPreTrainedModel):
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
encoder_outputs: Optional[list[torch.FloatTensor]] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
@ -1262,7 +1237,7 @@ class BartModel(BartPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Seq2SeqModelOutput]:
) -> Union[tuple, Seq2SeqModelOutput]:
r"""
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
@ -1283,7 +1258,7 @@ class BartModel(BartPreTrainedModel):
be used by default.
If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
information on the default strategy.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
@ -1425,8 +1400,8 @@ class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
encoder_outputs: Optional[list[torch.FloatTensor]] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@ -1435,7 +1410,7 @@ class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Seq2SeqLMOutput]:
) -> Union[tuple, Seq2SeqLMOutput]:
r"""
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
@ -1456,7 +1431,7 @@ class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
be used by default.
If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
information on the default strategy.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
@ -1612,7 +1587,7 @@ class BartForSequenceClassification(BartPreTrainedModel):
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
encoder_outputs: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@ -1621,7 +1596,7 @@ class BartForSequenceClassification(BartPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
r"""
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
@ -1642,7 +1617,7 @@ class BartForSequenceClassification(BartPreTrainedModel):
be used by default.
If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
information on the default strategy.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
@ -1757,7 +1732,7 @@ class BartForQuestionAnswering(BartPreTrainedModel):
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
encoder_outputs: Optional[list[torch.FloatTensor]] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
@ -1767,7 +1742,7 @@ class BartForQuestionAnswering(BartPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
) -> Union[tuple, Seq2SeqQuestionAnsweringModelOutput]:
r"""
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
@ -1788,7 +1763,7 @@ class BartForQuestionAnswering(BartPreTrainedModel):
be used by default.
If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
information on the default strategy.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
@ -1925,7 +1900,7 @@ class BartForCausalLM(BartPreTrainedModel, GenerationMixin):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@ -1933,7 +1908,7 @@ class BartForCausalLM(BartPreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
r"""
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

View File

@ -15,7 +15,7 @@
"""PyTorch M2M100 model."""
import math
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Optional, Union
import torch
from torch import nn
@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import (
from ...modeling_flash_attention_utils import (
FlashAttentionKwargs,
)
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@ -262,7 +263,7 @@ class M2M100Attention(nn.Module):
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
@ -335,7 +336,7 @@ class M2M100Attention(nn.Module):
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100, MBART->M2M100
class M2M100EncoderLayer(nn.Module):
class M2M100EncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: M2M100Config):
super().__init__()
self.embed_dim = config.d_model
@ -404,7 +405,7 @@ class M2M100EncoderLayer(nn.Module):
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100
class M2M100DecoderLayer(nn.Module):
class M2M100DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: M2M100Config, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
@ -876,28 +877,19 @@ class M2M100Encoder(M2M100PreTrainedModel):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
@ -963,7 +955,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
encoder_attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@ -1135,42 +1127,27 @@ class M2M100Decoder(M2M100PreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
@ -1255,8 +1232,8 @@ class M2M100Model(M2M100PreTrainedModel):
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
@ -1264,7 +1241,7 @@ class M2M100Model(M2M100PreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
r"""
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
@ -1383,8 +1360,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel, GenerationMixin):
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@ -1393,7 +1370,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
) -> Union[tuple[torch.Tensor], Seq2SeqLMOutput]:
r"""
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.

View File

@ -15,7 +15,7 @@
"""PyTorch NLLB-MoE model."""
import math
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Optional, Union
import torch
import torch.nn as nn
@ -32,6 +32,7 @@ from ...modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
MoEModelOutput,
MoEModelOutputWithPastAndCrossAttentions,
@ -95,7 +96,7 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T
r"""
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
experts is too unbalanced.
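For reference, the Switch Transformer auxiliary loss described here multiplies, for every expert, the fraction of tokens dispatched to it by the mean router probability it received, then scales by the number of experts. A hedged standalone sketch (the function name and tensor shapes are illustrative, not the signature used in this file):

import torch

def load_balancing_loss_sketch(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor:
    # router_probs: (..., num_experts) soft routing probabilities; expert_indices: hard assignment per token.
    num_experts = router_probs.shape[-1]
    one_hot = torch.nn.functional.one_hot(expert_indices.reshape(-1), num_experts).float()
    tokens_per_expert = one_hot.mean(dim=0)                               # f_e: fraction of tokens routed to expert e
    prob_per_expert = router_probs.reshape(-1, num_experts).mean(dim=0)   # P_e: mean router probability for expert e
    return num_experts * torch.sum(tokens_per_expert * prob_per_expert)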
@ -275,7 +276,7 @@ class NllbMoeTop2Router(nn.Module):
router_logits: torch.Tensor,
input_dtype: torch.dtype = torch.float32,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple:
) -> tuple:
"""
Computes the `dispatch_mask` and the `dispatch_weights` for each experts. The masks are adapted to the expert
capacity.
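Top-2 routing with a capacity cap can be reduced to the following hedged sketch (illustrative only; the real `route_tokens` also handles padding masks, dtype casting and second-expert policies):

import torch

def top2_dispatch_sketch(router_logits: torch.Tensor, expert_capacity: int):
    # router_logits: (num_tokens, num_experts). Pick the two most probable experts per token,
    # then drop tokens that overflow an expert's capacity.
    router_probs = torch.softmax(router_logits, dim=-1)
    top2_weights, top2_indices = torch.topk(router_probs, k=2, dim=-1)
    num_experts = router_probs.shape[-1]
    dispatch_mask = torch.nn.functional.one_hot(top2_indices, num_experts).sum(dim=1)  # (num_tokens, num_experts)
    position_in_expert = torch.cumsum(dispatch_mask, dim=0) * dispatch_mask            # 1-based slot within each expert
    dispatch_mask = dispatch_mask * (position_in_expert <= expert_capacity)
    return dispatch_mask, top2_weights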
@ -355,7 +356,7 @@ class NllbMoeTop2Router(nn.Module):
return top_1_mask, router_probs
def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.LongTensor] = None) -> Tuple:
def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.LongTensor] = None) -> tuple:
r"""
The hidden states are reshaped to simplify the computation of the router probabilities (combining weights for
each experts.)
@ -541,14 +542,14 @@ class NllbMoeAttention(nn.Module):
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if encoder_hidden_states are provided this layer is used as a cross-attention layer
@ -625,7 +626,7 @@ class NllbMoeAttention(nn.Module):
return attn_output, attn_weights, past_key_value
class NllbMoeEncoderLayer(nn.Module):
class NllbMoeEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: NllbMoeConfig, is_sparse: bool = False):
super().__init__()
self.embed_dim = config.d_model
@ -707,7 +708,7 @@ class NllbMoeEncoderLayer(nn.Module):
return outputs
class NllbMoeDecoderLayer(nn.Module):
class NllbMoeDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: NllbMoeConfig, is_sparse: bool = False):
super().__init__()
self.embed_dim = config.d_model
@ -747,7 +748,7 @@ class NllbMoeDecoderLayer(nn.Module):
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = True,
@ -1013,27 +1014,18 @@ class NllbMoeEncoder(NllbMoePreTrainedModel):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
dropout_probability = torch.rand([])
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
output_router_logits=output_router_logits,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
output_router_logits=output_router_logits,
)
hidden_states = layer_outputs[0]
@ -1136,7 +1128,7 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
encoder_attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@ -1285,7 +1277,7 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
@ -1296,37 +1288,18 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.forward,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
layer_head_mask,
cross_attn_layer_head_mask,
None, # past_key_value is always None with gradient checkpointing
use_cache,
output_attentions,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask,
cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask,
cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
)
hidden_states = layer_outputs[0]
@ -1494,8 +1467,8 @@ class NllbMoeModel(NllbMoePreTrainedModel):
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
@ -1503,7 +1476,7 @@ class NllbMoeModel(NllbMoePreTrainedModel):
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqMoEModelOutput]:
) -> Union[tuple[torch.Tensor], Seq2SeqMoEModelOutput]:
r"""
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
@ -1642,8 +1615,8 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel, GenerationMixin):
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@ -1652,7 +1625,7 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqMoEOutput]:
) -> Union[tuple[torch.Tensor], Seq2SeqMoEOutput]:
r"""
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.

View File

@ -22,7 +22,7 @@
import copy
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from typing import Optional, Union
import torch
from torch import Tensor, nn
@ -33,6 +33,7 @@ from ...generation import GenerationMixin
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@ -49,8 +50,14 @@ logger = logging.get_logger(__name__)
@dataclass
class SeamlessM4TGenerationOutput(ModelOutput):
@auto_docstring(
custom_intro="""
Class defining the generated outputs from [`SeamlessM4TModel`], [`SeamlessM4TForTextToText`],
[`SeamlessM4TForTextToSpeech`], [`SeamlessM4TForSpeechToSpeech`] and [`SeamlessM4TForSpeechToText`].
"""
)
class SeamlessM4TGenerationOutput(ModelOutput):
r"""
Class defining the generated outputs from [`SeamlessM4TModel`], [`SeamlessM4TForTextToText`],
[`SeamlessM4TForTextToSpeech`], [`SeamlessM4TForSpeechToSpeech`] and [`SeamlessM4TForSpeechToText`].
@ -71,8 +78,8 @@ class SeamlessM4TGenerationOutput(ModelOutput):
waveform: Optional[torch.FloatTensor] = None
waveform_lengths: Optional[torch.IntTensor] = None
sequences: Optional[Tuple[torch.FloatTensor]] = None
unit_sequences: Optional[Tuple[torch.FloatTensor]] = None
sequences: Optional[tuple[torch.FloatTensor]] = None
unit_sequences: Optional[tuple[torch.FloatTensor]] = None
class SeamlessM4TConformerSamePadLayer(nn.Module):
@ -133,7 +140,7 @@ class SeamlessM4TConformerPositionalConvEmbedding(nn.Module):
class SeamlessM4TConformerRotaryPositionalEmbedding(nn.Module):
"""Rotary positional embedding
Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://huggingface.co/papers/2104.09864
"""
def __init__(self, config):
@ -199,7 +206,7 @@ class SeamlessM4TConformerRelPositionalEmbedding(nn.Module):
# Reverse the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in https://arxiv.org/abs/1901.02860
# as in https://huggingface.co/papers/1901.02860
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
@ -341,7 +348,7 @@ class SeamlessM4TConformerSelfAttention(nn.Module):
# linear transformation for positional encoding
self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# as described in https://huggingface.co/papers/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
@ -351,7 +358,7 @@ class SeamlessM4TConformerSelfAttention(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
relative_position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
# self-attention mechanism
batch_size, sequence_length, hidden_size = hidden_states.size()
@ -383,7 +390,7 @@ class SeamlessM4TConformerSelfAttention(nn.Module):
" 'relative'"
)
# apply relative_position_embeddings to qk scores
# as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860
# as proposed in Transformer_XL: https://huggingface.co/papers/1901.02860
scores = self._apply_relative_embeddings(
query=query, key=key, relative_position_embeddings=relative_position_embeddings
)
@ -443,7 +450,7 @@ class SeamlessM4TConformerSelfAttention(nn.Module):
q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
# 3. attention score: first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# as described in https://huggingface.co/papers/1901.02860 Section 3.3
# => (batch, head, time1, time2)
scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
@ -466,8 +473,8 @@ class SeamlessM4TConformerSelfAttention(nn.Module):
return scores
class SeamlessM4TConformerEncoderLayer(nn.Module):
"""Conformer block based on https://arxiv.org/abs/2005.08100."""
class SeamlessM4TConformerEncoderLayer(GradientCheckpointingLayer):
"""Conformer block based on https://huggingface.co/papers/2005.08100."""
def __init__(self, config):
super().__init__()
@ -482,6 +489,8 @@ class SeamlessM4TConformerEncoderLayer(nn.Module):
self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
self.self_attn_dropout = nn.Dropout(dropout)
self.self_attn = SeamlessM4TConformerSelfAttention(config)
# Conformer Convolution
self.conv_module = SeamlessM4TConformerConvolutionModule(config)
# Feed-forward 2
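As context for the modules assembled in this `__init__`, the Conformer paper's macaron composition can be sketched as follows (the half-step residual scaling follows the paper; this is a sketch, not this file's exact forward):

def conformer_block_sketch(x, ffn1, self_attn, conv_module, ffn2, final_layer_norm):
    # Macaron-style Conformer block (https://huggingface.co/papers/2005.08100):
    # half-step FFN, self-attention, convolution, half-step FFN, final layer norm.
    x = x + 0.5 * ffn1(x)
    x = x + self_attn(x)
    x = x + conv_module(x)
    x = x + 0.5 * ffn2(x)
    return final_layer_norm(x)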
@ -588,7 +597,7 @@ class SeamlessM4TConformerEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
dropout_probability = torch.rand([])
skip_the_layer = (
@ -596,23 +605,13 @@ class SeamlessM4TConformerEncoder(nn.Module):
)
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
relative_position_embeddings,
output_attentions,
conv_attention_mask,
)
else:
layer_outputs = layer(
hidden_states,
attention_mask=attention_mask,
relative_position_embeddings=relative_position_embeddings,
output_attentions=output_attentions,
conv_attention_mask=conv_attention_mask,
)
layer_outputs = layer(
hidden_states,
attention_mask=attention_mask,
relative_position_embeddings=relative_position_embeddings,
output_attentions=output_attentions,
conv_attention_mask=conv_attention_mask,
)
hidden_states = layer_outputs[0]
if skip_the_layer:
@ -925,10 +924,10 @@ class SeamlessM4TAttention(nn.Module):
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if encoder_hidden_states are provided this layer is used as a cross-attention layer
@ -967,10 +966,10 @@ class SeamlessM4TAttention(nn.Module):
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# if cross_attention save tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# if uni-directional self-attention (decoder) save tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
@ -1057,7 +1056,7 @@ class SeamlessM4TFeedForwardNetwork(nn.Module):
return hidden_states
class SeamlessM4TEncoderLayer(nn.Module):
class SeamlessM4TEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: SeamlessM4TConfig, encoder_ffn_dim=None, encoder_attention_heads=None):
super().__init__()
encoder_ffn_dim = config.encoder_ffn_dim if encoder_ffn_dim is None else encoder_ffn_dim
@ -1120,7 +1119,7 @@ class SeamlessM4TEncoderLayer(nn.Module):
return outputs
class SeamlessM4TDecoderLayer(nn.Module):
class SeamlessM4TDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: SeamlessM4TConfig, decoder_ffn_dim=None, decoder_attention_heads=None):
super().__init__()
decoder_ffn_dim = config.decoder_ffn_dim if decoder_ffn_dim is None else decoder_ffn_dim
@ -1156,7 +1155,7 @@ class SeamlessM4TDecoderLayer(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
) -> torch.Tensor:
@ -1172,7 +1171,7 @@ class SeamlessM4TDecoderLayer(nn.Module):
encoder_attention_mask (`torch.FloatTensor`):
encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by
very large negative values.
past_key_value (`Tuple(torch.FloatTensor)`):
past_key_value (`tuple(torch.FloatTensor)`):
cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
@ -1292,14 +1291,14 @@ class SeamlessM4TPreTrainedModel(PreTrainedModel):
def compute_last_hidden_states_per_sample(
self,
hidden_states: Tuple[Tuple[torch.Tensor]],
hidden_states: tuple[tuple[torch.Tensor]],
beam_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Computes the last hidden states.
Parameters:
hidden_states (`Tuple[Tuple[torch.Tensor]]`):
hidden_states (`tuple[tuple[torch.Tensor]]`):
The generated hidden states. Tuple (one element for each generated token) of tuples (one element for
each layer of the decoder) of torch.FloatTensor of shape (batch_size*num_beams*num_return_sequences,
generated_length, hidden_size).
@ -1373,7 +1372,7 @@ class SeamlessM4TSpeechEncoder(SeamlessM4TPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
) -> Union[tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -1488,7 +1487,7 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutput]:
) -> Union[tuple, BaseModelOutput]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@ -1565,7 +1564,7 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
to_drop = False
if self.training:
dropout_probability = torch.rand([])
@ -1575,19 +1574,11 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel):
if to_drop:
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.forward,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
@ -1675,13 +1666,13 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -1740,7 +1731,7 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel):
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.training:
@ -1750,27 +1741,15 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
@ -1833,15 +1812,15 @@ class SeamlessM4TTextToUnitModel(SeamlessM4TPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -1911,7 +1890,7 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
return shifted_input_ids
SEAMLESS_M4T_COMMON_CUSTOM_ARGS_DOCSTRING = r"""
SEAMLESS_M4T_COMMON_CUSTOM_ARGS = r"""
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`):
Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
[`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
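A short usage sketch of how such `input_features` are typically produced (the checkpoint name and audio length are illustrative assumptions):

import numpy as np
from transformers import SeamlessM4TFeatureExtractor

feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/hf-seamless-m4t-medium")
waveform = np.random.randn(16000).astype(np.float32)  # ~1 second of 16 kHz audio (random placeholder)
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
print(inputs["input_features"].shape)  # (batch_size, sequence_length, num_banks)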
@ -1934,7 +1913,7 @@ SEAMLESS_M4T_COMMON_CUSTOM_ARGS_DOCSTRING = r"""
be used by default.
If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
information on the default strategy.
inputs_embeds (`torch.FloatTensor` of shape`(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
@ -1998,15 +1977,15 @@ class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel,
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS_DOCSTRING)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@ -2014,7 +1993,7 @@ class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@ -2344,7 +2323,7 @@ class SeamlessM4TCodeHifiGan(PreTrainedModel):
def forward(
self, input_ids: torch.LongTensor, spkr_id: torch.Tensor, lang_id: torch.Tensor
) -> Tuple[torch.Tensor]:
) -> tuple[torch.Tensor]:
"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@ -2362,7 +2341,7 @@ class SeamlessM4TCodeHifiGan(PreTrainedModel):
lang = self.language_embedding(lang_id).transpose(1, 2)
log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2))
dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1)
dur_out = torch.clamp(torch.round(torch.expm1(log_dur_pred)).long(), min=1)
# B x C x T
if hidden_states.size(0) == 1:
hidden_states = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2)
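A small worked example of the duration rounding and upsampling above (the numbers are illustrative):

import torch

log_dur_pred = torch.tensor([[-0.2, 0.1, 1.3, 2.0]])                      # predicted log-durations for 4 unit tokens
dur_out = torch.clamp(torch.round(torch.expm1(log_dur_pred)).long(), min=1)
print(dur_out)                                                            # tensor([[1, 1, 3, 6]]): every unit lasts >= 1 frame

hidden_states = torch.randn(1, 8, 4)                                      # (batch, channels, time)
upsampled = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2)
print(upsampled.shape)                                                    # torch.Size([1, 8, 11]), i.e. the sum of the durations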
@ -2423,66 +2402,6 @@ class SeamlessM4TCodeHifiGan(PreTrainedModel):
nn.utils.remove_weight_norm(self.hifi_gan.conv_post)
############ WHOLE MODEL related code ################
SEAMLESS_M4T_T2T_START_DOCSTRING = r"""
Generates sequences of token ids.
<Tip warning={true}>
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
For an overview of generation strategies and code examples, check out the [following
guide](./generation_strategies).
</Tip>
"""
SEAMLESS_M4T_T2T_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.Tensor` of varying shape depending on the modality, *optional*):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
tgt_lang (`str`, *optional*):
The language to use as target language for translation.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which had the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logit processor is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904).
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model.
"""
@auto_docstring(
custom_intro="""
The text-to-text SeamlessM4T Model transformer which can be used for T2TT.
@ -2536,15 +2455,15 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin):
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS_DOCSTRING)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@ -2553,7 +2472,7 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@ -2634,7 +2553,6 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin):
encoder_attentions=encoder_outputs.attentions,
)
@auto_docstring(custom_intro=SEAMLESS_M4T_T2T_START_DOCSTRING, custom_args=SEAMLESS_M4T_T2T_INPUTS_DOCSTRING)
def generate(
self,
input_ids=None,
@ -2647,6 +2565,58 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin):
**kwargs,
):
"""
Generates sequences of token ids.
<Tip warning={true}>
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
For an overview of generation strategies and code examples, check out the [following
guide](./generation_strategies).
</Tip>
Parameters:
input_ids (`torch.Tensor` of varying shape depending on the modality, *optional*):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
tgt_lang (`str`, *optional*):
The language to use as target language for translation.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which has the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logit processor is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
If provided, this function constrains the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://huggingface.co/papers/2010.00904).
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
kwargs (`dict[str, Any]`, *optional*):
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model.
Return:
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`. The possible
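# Minimal usage sketch for the text-to-text generation call documented above. The checkpoint name
# "facebook/hf-seamless-m4t-medium" and the "eng"/"fra" language codes are assumptions for illustration.
from transformers import AutoProcessor, SeamlessM4TForTextToText

processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
model = SeamlessM4TForTextToText.from_pretrained("facebook/hf-seamless-m4t-medium")

text_inputs = processor(text="Hello, my dog is cute", src_lang="eng", return_tensors="pt")
# tgt_lang picks the output language; extra keywords (e.g. num_beams) override the generation config for this call
output_tokens = model.generate(**text_inputs, tgt_lang="fra", num_beams=4)
print(processor.decode(output_tokens[0].tolist(), skip_special_tokens=True))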
@ -2705,60 +2675,6 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin):
return reordered_past
SEAMLESS_M4T_S2T_START_DOCSTRING = r"""
Generates sequences of token ids.
<Tip warning={true}>
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
For an overview of generation strategies and code examples, check out the [following
guide](./generation_strategies).
</Tip>
"""
SEAMLESS_M4T_S2T_INPUTS_DOCSTRING = r"""
Args:
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`):
Input audio features. This should be returnes by the [`SeamlessM4TFeatureExtractor`] class or the
[`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
tgt_lang (`str`, *optional*):
The language to use as target language for translation.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which had the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logit processor is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904).
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model.
"""
@auto_docstring(
custom_intro="""
The speech-to-text SeamlessM4T Model transformer which can be used for S2TT.
@ -2807,15 +2723,15 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel, GenerationMixin):
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS_DOCSTRING)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS)
def forward(
self,
input_features: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@ -2824,7 +2740,7 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@ -2912,7 +2828,6 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel, GenerationMixin):
encoder_attentions=encoder_outputs.attentions,
)
@auto_docstring(custom_intro=SEAMLESS_M4T_S2T_START_DOCSTRING, custom_args=SEAMLESS_M4T_S2T_INPUTS_DOCSTRING)
def generate(
self,
input_features=None,
@ -2925,6 +2840,55 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel, GenerationMixin):
**kwargs,
):
"""
Generates sequences of token ids.
<Tip warning={true}>
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
For an overview of generation strategies and code examples, check out the [following
guide](./generation_strategies).
</Tip>
Parameters:
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`):
Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
[`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
tgt_lang (`str`, *optional*):
The language to use as target language for translation.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which has the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logit processor is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
If provided, this function constrains the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://huggingface.co/papers/2010.00904).
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
kwargs (`dict[str, Any]`, *optional*):
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model.
Return:
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`. The possible
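# Analogous usage sketch for the speech-to-text path described above. The checkpoint name is an
# assumption, and the zero waveform below only stands in for real 16 kHz mono audio.
import numpy as np
from transformers import AutoProcessor, SeamlessM4TForSpeechToText

processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
model = SeamlessM4TForSpeechToText.from_pretrained("facebook/hf-seamless-m4t-medium")

audio = np.zeros(16000, dtype=np.float32)  # 1 second of silence as a stand-in input
audio_inputs = processor(audios=audio, sampling_rate=16000, return_tensors="pt")
output_tokens = model.generate(**audio_inputs, tgt_lang="eng")
print(processor.decode(output_tokens[0].tolist(), skip_special_tokens=True))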
@ -3083,15 +3047,15 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS_DOCSTRING)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@ -3099,7 +3063,7 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@ -3241,7 +3205,7 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
Returns:
`Union[SeamlessM4TGenerationOutput, Tuple[Tensor]]`:
`Union[SeamlessM4TGenerationOutput, tuple[Tensor]]`:
- If `return_intermediate_token_ids`, returns [`SeamlessM4TGenerationOutput`].
- If not `return_intermediate_token_ids`, returns a tuple composed of waveforms of shape `(batch_size,
sequence_length)` and `waveform_lengths` which gives the length of each sample.
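# Sketch of the default return described above: without return_intermediate_token_ids the call yields a
# (waveform, waveform_lengths) pair. The checkpoint name and language codes are assumptions for illustration.
from transformers import AutoProcessor, SeamlessM4TForTextToSpeech

processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
model = SeamlessM4TForTextToSpeech.from_pretrained("facebook/hf-seamless-m4t-medium")

text_inputs = processor(text="Hello, my dog is cute", src_lang="eng", return_tensors="pt")
waveform, waveform_lengths = model.generate(**text_inputs, tgt_lang="fra")
# waveform: (batch_size, sequence_length) audio at the vocoder sampling rate; waveform_lengths: valid samples per item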
@ -3415,15 +3379,15 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS_DOCSTRING)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS)
def forward(
self,
input_features: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@ -3432,7 +3396,7 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@ -3554,7 +3518,7 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
Args:
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`):
Input audio features. This should be returnes by the [`SeamlessM4TFeatureExtractor`] class or the
Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
[`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
return_intermediate_token_ids (`bool`, *optional*):
If `True`, also returns the intermediate generated text and unit tokens. Set to `True` if you also want
@ -3578,7 +3542,7 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
Returns:
`Union[SeamlessM4TGenerationOutput, Tuple[Tensor]]`:
`Union[SeamlessM4TGenerationOutput, tuple[Tensor]]`:
- If `return_intermediate_token_ids`, returns [`SeamlessM4TGenerationOutput`].
- If not `return_intermediate_token_ids`, returns a tuple composed of waveforms of shape `(batch_size,
sequence_length)` and `waveform_lengths` which gives the length of each sample.
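# Matching sketch for the speech-to-speech path: source-language audio in, target-language audio out.
# The checkpoint name is an assumption and the zero waveform is only a stand-in for real 16 kHz input.
import numpy as np
from transformers import AutoProcessor, SeamlessM4TForSpeechToSpeech

processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
model = SeamlessM4TForSpeechToSpeech.from_pretrained("facebook/hf-seamless-m4t-medium")

audio_inputs = processor(audios=np.zeros(16000, dtype=np.float32), sampling_rate=16000, return_tensors="pt")
waveform, waveform_lengths = model.generate(**audio_inputs, tgt_lang="fra")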
@ -3784,7 +3748,7 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS_DOCSTRING)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@ -3792,8 +3756,8 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@ -3802,7 +3766,7 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@ -3970,7 +3934,7 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
[What are input IDs?](../glossary#input-ids)
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`, *optional*):
Input audio features. This should be returnes by the [`SeamlessM4TFeatureExtractor`] class or the
Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
[`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
return_intermediate_token_ids (`bool`, *optional*):
If `True`, also returns the intermediate generated text and unit tokens. Set to `True` if you also want
@ -3996,7 +3960,7 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
other.
Returns:
`Union[SeamlessM4TGenerationOutput, Tuple[Tensor], ModelOutput]`:
`Union[SeamlessM4TGenerationOutput, tuple[Tensor], ModelOutput]`:
- If `generate_speech` and `return_intermediate_token_ids`, returns [`SeamlessM4TGenerationOutput`].
- If `generate_speech` and not `return_intermediate_token_ids`, returns a tuple composed of waveforms of
shape `(batch_size, sequence_length)` and `waveform_lengths` which gives the length of each sample.
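# Sketch of the generate_speech switch described above for the combined SeamlessM4TModel. Checkpoint name
# and language codes are assumptions; only the return type changes between the two calls.
from transformers import AutoProcessor, SeamlessM4TModel

processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
model = SeamlessM4TModel.from_pretrained("facebook/hf-seamless-m4t-medium")

text_inputs = processor(text="Hello, my dog is cute", src_lang="eng", return_tensors="pt")
waveform, waveform_lengths = model.generate(**text_inputs, tgt_lang="fra")        # speech output (default)
text_only = model.generate(**text_inputs, tgt_lang="fra", generate_speech=False)  # text token ids, no vocoder pass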
View File
@ -17,7 +17,7 @@
import copy
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from typing import Optional, Union
import torch
from torch import Tensor, nn
@ -31,6 +31,7 @@ from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_causal_attention_mask,
)
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@ -68,8 +69,14 @@ _CONFIG_FOR_DOC = "SeamlessM4TConfig"
@dataclass
class SeamlessM4TGenerationOutput(ModelOutput):
@auto_docstring(
custom_intro="""
Class defining the generated outputs from [`SeamlessM4TModel`], [`SeamlessM4TForTextToText`],
[`SeamlessM4TForTextToSpeech`], [`SeamlessM4TForSpeechToSpeech`] and [`SeamlessM4TForTextToSpeech`].
"""
)
class SeamlessM4TGenerationOutput(ModelOutput):
r"""
Class defining the generated outputs from [`SeamlessM4TModel`], [`SeamlessM4TForTextToText`],
[`SeamlessM4TForTextToSpeech`], [`SeamlessM4TForSpeechToSpeech`] and [`SeamlessM4TForTextToSpeech`].
@ -90,11 +97,11 @@ class SeamlessM4TGenerationOutput(ModelOutput):
waveform: Optional[torch.FloatTensor] = None
waveform_lengths: Optional[torch.IntTensor] = None
sequences: Optional[Tuple[torch.FloatTensor]] = None
unit_sequences: Optional[Tuple[torch.FloatTensor]] = None
sequences: Optional[tuple[torch.FloatTensor]] = None
unit_sequences: Optional[tuple[torch.FloatTensor]] = None
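# Sketch of how the fields above get filled: asking a speech-generating model for intermediate ids returns
# this dataclass instead of a plain (waveform, waveform_lengths) tuple. Checkpoint and languages are assumptions.
from transformers import AutoProcessor, SeamlessM4TForTextToSpeech

processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
model = SeamlessM4TForTextToSpeech.from_pretrained("facebook/hf-seamless-m4t-medium")

inputs = processor(text="Hello", src_lang="eng", return_tensors="pt")
out = model.generate(**inputs, tgt_lang="fra", return_intermediate_token_ids=True)
audio, lengths = out.waveform, out.waveform_lengths     # synthesized audio and per-sample lengths
text_ids, unit_ids = out.sequences, out.unit_sequences  # intermediate text tokens and vocoder unit tokens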
SEAMLESS_M4T_COMMON_CUSTOM_ARGS_DOCSTRING = r"""
SEAMLESS_M4T_COMMON_CUSTOM_ARGS = r"""
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`):
Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
[`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
@ -117,7 +124,7 @@ SEAMLESS_M4T_COMMON_CUSTOM_ARGS_DOCSTRING = r"""
be used by default.
If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
information on the default strategy.
inputs_embeds (`torch.FloatTensor` of shape`(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
@ -337,13 +344,13 @@ class SeamlessM4TConformerSelfAttention(Wav2Vec2ConformerSelfAttention, nn.Modul
# linear transformation for positional encoding
self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# as described in https://huggingface.co/papers/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
class SeamlessM4TConformerEncoderLayer(Wav2Vec2ConformerEncoderLayer):
"""Conformer block based on https://arxiv.org/abs/2005.08100."""
"""Conformer block based on https://huggingface.co/papers/2005.08100."""
def __init__(self, config):
super().__init__(config)
@ -351,6 +358,8 @@ class SeamlessM4TConformerEncoderLayer(Wav2Vec2ConformerEncoderLayer):
self.ffn1 = SeamlessM4TConformerFeedForward(config)
self.self_attn_dropout = nn.Dropout(dropout)
self.self_attn = SeamlessM4TConformerSelfAttention(config)
# Conformer Convolution
self.conv_module = SeamlessM4TConformerConvolutionModule(config)
self.ffn2 = SeamlessM4TConformerFeedForward(config)
@ -453,7 +462,7 @@ class SeamlessM4TConformerEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
dropout_probability = torch.rand([])
skip_the_layer = (
@ -461,23 +470,13 @@ class SeamlessM4TConformerEncoder(nn.Module):
)
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
relative_position_embeddings,
output_attentions,
conv_attention_mask,
)
else:
layer_outputs = layer(
hidden_states,
attention_mask=attention_mask,
relative_position_embeddings=relative_position_embeddings,
output_attentions=output_attentions,
conv_attention_mask=conv_attention_mask,
)
layer_outputs = layer(
hidden_states,
attention_mask=attention_mask,
relative_position_embeddings=relative_position_embeddings,
output_attentions=output_attentions,
conv_attention_mask=conv_attention_mask,
)
hidden_states = layer_outputs[0]
if skip_the_layer:
@ -633,10 +632,10 @@ class SeamlessM4TAttention(BartAttention):
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if encoder_hidden_states are provided this layer is used as a cross-attention layer
@ -675,10 +674,10 @@ class SeamlessM4TAttention(BartAttention):
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# if cross_attention save tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# if uni-directional self-attention (decoder) save tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
@ -747,7 +746,7 @@ class SeamlessM4TFeedForwardNetwork(NllbMoeDenseActDense):
self.fc2 = nn.Linear(ffn_dim, config.hidden_size)
class SeamlessM4TEncoderLayer(nn.Module):
class SeamlessM4TEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: SeamlessM4TConfig, encoder_ffn_dim=None, encoder_attention_heads=None):
super().__init__()
encoder_ffn_dim = config.encoder_ffn_dim if encoder_ffn_dim is None else encoder_ffn_dim
@ -810,7 +809,7 @@ class SeamlessM4TEncoderLayer(nn.Module):
return outputs
class SeamlessM4TDecoderLayer(nn.Module):
class SeamlessM4TDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: SeamlessM4TConfig, decoder_ffn_dim=None, decoder_attention_heads=None):
super().__init__()
decoder_ffn_dim = config.decoder_ffn_dim if decoder_ffn_dim is None else decoder_ffn_dim
@ -846,7 +845,7 @@ class SeamlessM4TDecoderLayer(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
) -> torch.Tensor:
@ -862,7 +861,7 @@ class SeamlessM4TDecoderLayer(nn.Module):
encoder_attention_mask (`torch.FloatTensor`):
encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by
very large negative values.
past_key_value (`Tuple(torch.FloatTensor)`):
past_key_value (`tuple(torch.FloatTensor)`):
cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
@ -982,14 +981,14 @@ class SeamlessM4TPreTrainedModel(PreTrainedModel):
def compute_last_hidden_states_per_sample(
self,
hidden_states: Tuple[Tuple[torch.Tensor]],
hidden_states: tuple[tuple[torch.Tensor]],
beam_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Computes the last hidden states.
Parameters:
hidden_states (`Tuple[Tuple[torch.Tensor]]`):
hidden_states (`tuple[tuple[torch.Tensor]]`):
The generated hidden states. Tuple (one element for each generated token) of tuples (one element for
each layer of the decoder) of torch.FloatTensor of shape (batch_size*num_beams*num_return_sequences,
generated_length, hidden_size).
@ -1063,7 +1062,7 @@ class SeamlessM4TSpeechEncoder(SeamlessM4TPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
) -> Union[tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -1178,7 +1177,7 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutput]:
) -> Union[tuple, BaseModelOutput]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@ -1255,7 +1254,7 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
to_drop = False
if self.training:
dropout_probability = torch.rand([])
@ -1265,19 +1264,11 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel):
if to_drop:
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.forward,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
@ -1365,13 +1356,13 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -1430,7 +1421,7 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel):
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.training:
@ -1440,27 +1431,15 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
@ -1523,15 +1502,15 @@ class SeamlessM4TTextToUnitModel(SeamlessM4TPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -1640,15 +1619,15 @@ class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel,
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS_DOCSTRING)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@ -1656,7 +1635,7 @@ class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@ -1948,7 +1927,7 @@ class SeamlessM4TCodeHifiGan(PreTrainedModel):
def forward(
self, input_ids: torch.LongTensor, spkr_id: torch.Tensor, lang_id: torch.Tensor
) -> Tuple[torch.Tensor]:
) -> tuple[torch.Tensor]:
"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@ -1966,7 +1945,7 @@ class SeamlessM4TCodeHifiGan(PreTrainedModel):
lang = self.language_embedding(lang_id).transpose(1, 2)
log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2))
dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1)
dur_out = torch.clamp(torch.round(torch.expm1(log_dur_pred)).long(), min=1)
# B x C x T
if hidden_states.size(0) == 1:
hidden_states = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2)
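# Toy check of the duration round-trip used above (plain tensors, no model): the predictor emits
# log(1 + duration), expm1 inverts it, and each unit's hidden vector is repeated `duration` times.
import torch

log_dur_pred = torch.tensor([[0.0, 1.1, 2.0]])                                # (batch=1, units=3)
dur_out = torch.clamp(torch.round(torch.expm1(log_dur_pred)).long(), min=1)   # expm1 -> [0.0, ~2.0, ~6.4] -> [1, 2, 6]
hidden_states = torch.arange(3.0).view(1, 1, 3)                               # (batch=1, channels=1, time=3)
upsampled = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2)
# time axis becomes [0, 1, 1, 2, 2, 2, 2, 2, 2]: 1 + 2 + 6 = 9 frames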
@ -2140,15 +2119,15 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin):
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS_DOCSTRING)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@ -2157,7 +2136,7 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@ -2238,7 +2217,6 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin):
encoder_attentions=encoder_outputs.attentions,
)
@auto_docstring(custom_intro=SEAMLESS_M4T_T2T_START_DOCSTRING, custom_args=SEAMLESS_M4T_T2T_INPUTS_DOCSTRING)
def generate(
self,
input_ids=None,
@ -2251,6 +2229,58 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin):
**kwargs,
):
"""
Generates sequences of token ids.
<Tip warning={true}>
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
For an overview of generation strategies and code examples, check out the [following
guide](./generation_strategies).
</Tip>
Parameters:
input_ids (`torch.Tensor` of varying shape depending on the modality, *optional*):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
tgt_lang (`str`, *optional*):
The language to use as target language for translation.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which has the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logit processor is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
If provided, this function constrains the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://huggingface.co/papers/2010.00904).
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
kwargs (`dict[str, Any]`, *optional*):
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model.
Return:
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`. The possible
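# Sketch of the generation-control hooks listed above: per-call keyword overrides plus a prefix-constrained
# decoding hook. The checkpoint, language codes, and the no-op allow-list are assumptions for illustration.
from transformers import AutoProcessor, SeamlessM4TForTextToText

processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
model = SeamlessM4TForTextToText.from_pretrained("facebook/hf-seamless-m4t-medium")
text_inputs = processor(text="Hello, my dog is cute", src_lang="eng", return_tensors="pt")

def allow_all(batch_id, input_ids):
    # a real constraint would narrow this list based on batch_id and the tokens generated so far
    return list(range(model.config.vocab_size))

output_tokens = model.generate(
    **text_inputs,
    tgt_lang="fra",
    num_beams=4,            # overrides generation_config.num_beams for this call only
    max_new_tokens=64,
    prefix_allowed_tokens_fn=allow_all,
)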
@ -2411,15 +2441,15 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel, GenerationMixin):
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS_DOCSTRING)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS)
def forward(
self,
input_features: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@ -2428,7 +2458,7 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@ -2516,7 +2546,6 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel, GenerationMixin):
encoder_attentions=encoder_outputs.attentions,
)
@auto_docstring(custom_intro=SEAMLESS_M4T_S2T_START_DOCSTRING, custom_args=SEAMLESS_M4T_S2T_INPUTS_DOCSTRING)
def generate(
self,
input_features=None,
@ -2529,6 +2558,55 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel, GenerationMixin):
**kwargs,
):
"""
Generates sequences of token ids.
<Tip warning={true}>
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
For an overview of generation strategies and code examples, check out the [following
guide](./generation_strategies).
</Tip>
Parameters:
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`):
Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
[`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
tgt_lang (`str`, *optional*):
The language to use as target language for translation.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which had the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logit processor is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
If provided, this function constrains the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://huggingface.co/papers/2010.00904).
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
kwargs (`dict[str, Any]`, *optional*):
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model.
Return:
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`. The possible
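# Sketch of the two return types described above: with return_dict_in_generate=True the call returns a
# ModelOutput whose .sequences field holds the token ids, otherwise a plain id tensor. The checkpoint and
# the silent stand-in audio are assumptions for illustration.
import numpy as np
from transformers import AutoProcessor, SeamlessM4TForSpeechToText

processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
model = SeamlessM4TForSpeechToText.from_pretrained("facebook/hf-seamless-m4t-medium")

inputs = processor(audios=np.zeros(16000, dtype=np.float32), sampling_rate=16000, return_tensors="pt")
out = model.generate(**inputs, tgt_lang="eng", return_dict_in_generate=True, output_scores=True)
token_ids = out.sequences   # (batch_size, sequence_length) token ids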
@ -2648,15 +2726,15 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS_DOCSTRING)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@ -2664,7 +2742,7 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@ -2806,7 +2884,7 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
Returns:
`Union[SeamlessM4TGenerationOutput, Tuple[Tensor]]`:
`Union[SeamlessM4TGenerationOutput, tuple[Tensor]]`:
- If `return_intermediate_token_ids`, returns [`SeamlessM4TGenerationOutput`].
- If not `return_intermediate_token_ids`, returns a tuple composed of waveforms of shape `(batch_size,
sequence_length)` and `waveform_lengths` which gives the length of each sample.
@ -2980,15 +3058,15 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS_DOCSTRING)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS)
def forward(
self,
input_features: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@ -2997,7 +3075,7 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@ -3119,7 +3197,7 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
Args:
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`):
Input audio features. This should be returnes by the [`SeamlessM4TFeatureExtractor`] class or the
Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
[`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
return_intermediate_token_ids (`bool`, *optional*):
If `True`, also returns the intermediate generated text and unit tokens. Set to `True` if you also want
@ -3143,7 +3221,7 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin):
Returns:
`Union[SeamlessM4TGenerationOutput, Tuple[Tensor]]`:
`Union[SeamlessM4TGenerationOutput, tuple[Tensor]]`:
- If `return_intermediate_token_ids`, returns [`SeamlessM4TGenerationOutput`].
- If not `return_intermediate_token_ids`, returns a tuple composed of waveforms of shape `(batch_size,
sequence_length)` and `waveform_lengths` which gives the length of each sample.
@ -3349,7 +3427,7 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS_DOCSTRING)
@auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@ -3357,8 +3435,8 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@ -3367,7 +3445,7 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@ -3535,7 +3613,7 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
[What are input IDs?](../glossary#input-ids)
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`, *optional*):
Input audio features. This should be returnes by the [`SeamlessM4TFeatureExtractor`] class or the
Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
[`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
return_intermediate_token_ids (`bool`, *optional*):
If `True`, also returns the intermediate generated text and unit tokens. Set to `True` if you also want
@ -3561,7 +3639,7 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
other.
Returns:
`Union[SeamlessM4TGenerationOutput, Tuple[Tensor], ModelOutput]`:
`Union[SeamlessM4TGenerationOutput, tuple[Tensor], ModelOutput]`:
- If `generate_speech` and `return_intermediate_token_ids`, returns [`SeamlessM4TGenerationOutput`].
- If `generate_speech` and not `return_intermediate_token_ids`, returns a tuple composed of waveforms of
shape `(batch_size, sequence_length)` and `waveform_lengths` which gives the length of each sample.
View File
@ -1,9 +1,3 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/speecht5/modular_speecht5.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_speecht5.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2023 The Fairseq Authors, Microsoft Research, and the HuggingFace Inc. team. All rights reserved.
#
@ -18,12 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch SpeechT5 model."""
import math
from typing import Optional, Union, list, tuple
from typing import Optional, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss
@ -48,371 +44,50 @@ from .configuration_speecht5 import SpeechT5Config, SpeechT5HifiGanConfig
logger = logging.get_logger(__name__)
class SpeechT5SinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length."""
_HIDDEN_STATES_START_POSITION = 1
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
super().__init__()
self.offset = 2
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
if hasattr(self, "weights"):
# in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.register_buffer("weights", emb_weights, persistent=False)
@staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
"""
Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
description in Section 3.5 of "Attention Is All You Need".
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
if embedding_dim % 2 == 1:
# zero pad
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb.to(torch.get_default_dtype())
@torch.no_grad()
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
bsz, seq_len = input_ids.size()
# Create the position ids from the input token ids. Any padded tokens remain padded.
position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
input_ids.device
)
# expand embeddings if needed
max_pos = self.padding_idx + 1 + seq_len
if max_pos > self.weights.size(0):
self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
def create_position_ids_from_input_ids(
self, input_ids: torch.Tensor, padding_idx: int, past_key_values_length: Optional[int] = 0
):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
symbols are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
x: torch.Tensor x:
Returns: torch.Tensor
"""
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
return incremental_indices.long() + padding_idx
# General docstring
class SpeechT5SamePadLayer(nn.Module):
def __init__(self, num_conv_pos_embeddings):
super().__init__()
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
def forward(self, hidden_states):
if self.num_pad_remove > 0:
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
return hidden_states
class SpeechT5PositionalConvEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
self.conv = nn.Conv1d(
config.hidden_size,
config.hidden_size,
kernel_size=config.num_conv_pos_embeddings,
padding=config.num_conv_pos_embeddings // 2,
groups=config.num_conv_pos_embedding_groups,
)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = weight_norm(self.conv, name="weight", dim=2)
if hasattr(self.conv, "parametrizations"):
weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
else:
weight_g = self.conv.weight_g
weight_v = self.conv.weight_v
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else:
self.conv = weight_norm(self.conv, name="weight", dim=2)
self.padding = SpeechT5SamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.conv(hidden_states)
hidden_states = self.padding(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
return hidden_states
class SpeechT5ScaledPositionalEncoding(nn.Module):
"""
Scaled positional encoding, see §3.2 in https://huggingface.co/papers/1809.08895
"""
def __init__(self, dropout, dim, max_len=5000):
pe = torch.zeros(max_len, dim)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.int64).float() * -(math.log(10000.0) / dim))
pe[:, 0::2] = torch.sin(position.float() * div_term)
pe[:, 1::2] = torch.cos(position.float() * div_term)
pe = pe.unsqueeze(0)
super().__init__()
self.register_buffer("pe", pe, persistent=False)
self.dropout = nn.Dropout(p=dropout)
self.dim = dim
self.alpha = torch.nn.Parameter(torch.tensor(1.0))
def forward(self, emb):
emb = emb + self.alpha * self.pe[:, : emb.size(1)]
emb = self.dropout(emb)
return emb
class SpeechT5RelativePositionalEncoding(torch.nn.Module):
def __init__(self, dim, max_length=1000):
super().__init__()
self.dim = dim
self.max_length = max_length
self.pe_k = torch.nn.Embedding(2 * max_length, dim)
def forward(self, hidden_states):
seq_len = hidden_states.shape[1]
pos_seq = torch.arange(0, seq_len).to(device=hidden_states.device, dtype=torch.long)
pos_seq = pos_seq[:, None] - pos_seq[None, :]
pos_seq[pos_seq < -self.max_length] = -self.max_length
pos_seq[pos_seq >= self.max_length] = self.max_length - 1
pos_seq = pos_seq + self.max_length
return self.pe_k(pos_seq)
# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
shifted_input_ids[:, 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
def shift_spectrograms_right(
input_values: torch.Tensor, reduction_factor: int = 1, attention_mask: Optional[torch.Tensor] = None
):
"""
Shift input spectrograms one timestep to the right. Also applies the reduction factor to the sequence length.
"""
# thin out frames for reduction factor
if reduction_factor > 1:
input_values = input_values[:, reduction_factor - 1 :: reduction_factor]
if attention_mask is not None:
attention_mask = attention_mask[:, reduction_factor - 1 :: reduction_factor]
shifted_input_values = input_values.new_zeros(input_values.shape)
shifted_input_values[:, 1:] = input_values[:, :-1].clone()
# replace possible -100 values in labels by zeros
shifted_input_values.masked_fill_(shifted_input_values == -100.0, 0.0)
return shifted_input_values, attention_mask
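# Illustrative usage sketch (not part of this diff), assuming the two shift helpers defined
# just above are in scope: labels are shifted one step to the right, -100 targets handled.
import torch

labels = torch.tensor([[10, 11, 12, -100]])
decoder_input_ids = shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2)
print(decoder_input_ids)        # tensor([[ 2, 10, 11, 12]])

spectrogram = torch.arange(8.0).reshape(1, 4, 2)   # (batch, frames, num_mel_bins)
shifted, _ = shift_spectrograms_right(spectrogram, reduction_factor=2)
print(shifted.shape)            # torch.Size([1, 2, 2]); frame 0 is all zeros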
class SpeechT5NoLayerNormConvLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class SpeechT5LayerNormConvLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.activation(hidden_states)
return hidden_states
class SpeechT5GroupNormConvLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class SpeechT5FeatureEncoder(nn.Module):
"""Construct the features from raw audio waveform"""
def __init__(self, config):
super().__init__()
if config.feat_extract_norm == "group":
conv_layers = [SpeechT5GroupNormConvLayer(config, layer_id=0)] + [
SpeechT5NoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
]
elif config.feat_extract_norm == "layer":
conv_layers = [
SpeechT5LayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
]
else:
raise ValueError(
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
)
self.conv_layers = nn.ModuleList(conv_layers)
self.gradient_checkpointing = False
self._requires_grad = True
def _freeze_parameters(self):
for param in self.parameters():
param.requires_grad = False
self._requires_grad = False
def forward(self, input_values):
hidden_states = input_values[:, None]
# make sure hidden_states require grad for gradient_checkpointing
if self._requires_grad and self.training:
hidden_states.requires_grad = True
for conv_layer in self.conv_layers:
hidden_states = conv_layer(hidden_states)
return hidden_states
class SpeechT5FeatureProjection(nn.Module):
def __init__(self, config):
super().__init__()
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
self.dropout = nn.Dropout(config.feat_proj_dropout)
def forward(self, hidden_states):
# non-projected hidden states are needed for quantization
norm_hidden_states = self.layer_norm(hidden_states)
hidden_states = self.projection(norm_hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states, norm_hidden_states
@auto_docstring
class SpeechT5PreTrainedModel(PreTrainedModel):
config_class = SpeechT5Config
base_model_prefix = "speecht5"
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = False
_supports_sdpa = False
_supports_flex_attn = False
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, SpeechT5PositionalConvEmbedding):
nn.init.normal_(
module.conv.weight,
mean=0,
std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
)
nn.init.constant_(module.conv.bias, 0)
elif isinstance(module, SpeechT5FeatureProjection):
k = math.sqrt(1 / module.projection.in_features)
nn.init.uniform_(module.projection.weight, a=-k, b=k)
nn.init.uniform_(module.projection.bias, a=-k, b=k)
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
nn.init.kaiming_normal_(module.weight)
if module.bias is not None:
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
nn.init.uniform_(module.bias, a=-k, b=k)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the convolutional layers
"""
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
return input_lengths
def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
# Effectively attention_mask.sum(-1), but not in-place so that it can run
# in inference mode.
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
batch_size = attention_mask.shape[0]
attention_mask = torch.zeros(
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
)
# these two operations make sure that all values before the output length indices are attended to
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: tuple[int, int],
mask_prob: float,
@ -532,9 +207,312 @@ def _compute_mask_indices(
return spec_aug_mask
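# Illustrative sketch (not part of this diff): the kind of boolean span mask that
# _compute_mask_indices produces for SpecAugment-style masking, built here by hand with
# assumed span starts rather than the function's random sampling.
import numpy as np

batch_size, sequence_length, mask_length = 2, 10, 3
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
start_indices = np.array([[1], [6]])                 # assumed span starts, one per example
offsets = np.arange(mask_length)[None, :]
np.put_along_axis(spec_aug_mask, start_indices + offsets, True, axis=-1)
print(spec_aug_mask.astype(int))
# [[0 1 1 1 0 0 0 0 0 0]
#  [0 0 0 0 0 0 1 1 1 0]]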
class SpeechT5SpeechEncoderPrenet(SpeechT5PreTrainedModel):
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SpeechT5
class SpeechT5NoLayerNormConvLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SpeechT5
class SpeechT5LayerNormConvLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SpeechT5
class SpeechT5GroupNormConvLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextSinusoidalPositionalEmbedding with Speech2Text->SpeechT5
class SpeechT5SinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length."""
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
super().__init__()
self.offset = 2
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
if hasattr(self, "weights"):
# in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.register_buffer("weights", emb_weights, persistent=False)
@staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
"""
Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
description in Section 3.5 of "Attention Is All You Need".
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
if embedding_dim % 2 == 1:
# zero pad
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb.to(torch.get_default_dtype())
@torch.no_grad()
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
bsz, seq_len = input_ids.size()
# Create the position ids from the input token ids. Any padded tokens remain padded.
position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
input_ids.device
)
# expand embeddings if needed
max_pos = self.padding_idx + 1 + seq_len
if max_pos > self.weights.size(0):
self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
def create_position_ids_from_input_ids(
self, input_ids: torch.Tensor, padding_idx: int, past_key_values_length: Optional[int] = 0
):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
symbols are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
input_ids (`torch.Tensor`): Input token ids; padded positions equal `padding_idx`.
padding_idx (`int`): Index of the padding token.
past_key_values_length (`int`, *optional*): Number of positions already cached during incremental decoding.
Returns: torch.Tensor
"""
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
return incremental_indices.long() + padding_idx
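# Illustrative sketch (not part of this diff): padded tokens keep padding_idx as their
# position while real tokens are numbered from padding_idx + 1, mirroring the method above.
import torch

padding_idx = 1
input_ids = torch.tensor([[5, 7, 9, padding_idx, padding_idx]])
mask = input_ids.ne(padding_idx).int()                                  # tensor([[1, 1, 1, 0, 0]])
position_ids = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
print(position_ids)                                                     # tensor([[2, 3, 4, 1, 1]])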
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->SpeechT5
class SpeechT5PositionalConvEmbedding(nn.Module):
def __init__(self, config):
super().__init__(config)
super().__init__()
self.conv = nn.Conv1d(
config.hidden_size,
config.hidden_size,
kernel_size=config.num_conv_pos_embeddings,
padding=config.num_conv_pos_embeddings // 2,
groups=config.num_conv_pos_embedding_groups,
)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = weight_norm(self.conv, name="weight", dim=2)
if hasattr(self.conv, "parametrizations"):
weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
else:
weight_g = self.conv.weight_g
weight_v = self.conv.weight_v
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else:
self.conv = weight_norm(self.conv, name="weight", dim=2)
self.padding = SpeechT5SamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.conv(hidden_states)
hidden_states = self.padding(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
return hidden_states
class SpeechT5ScaledPositionalEncoding(nn.Module):
"""
Scaled positional encoding, see §3.2 in https://huggingface.co/papers/1809.08895
"""
def __init__(self, dropout, dim, max_len=5000):
pe = torch.zeros(max_len, dim)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.int64).float() * -(math.log(10000.0) / dim))
pe[:, 0::2] = torch.sin(position.float() * div_term)
pe[:, 1::2] = torch.cos(position.float() * div_term)
pe = pe.unsqueeze(0)
super().__init__()
self.register_buffer("pe", pe, persistent=False)
self.dropout = nn.Dropout(p=dropout)
self.dim = dim
self.alpha = torch.nn.Parameter(torch.tensor(1.0))
def forward(self, emb):
emb = emb + self.alpha * self.pe[:, : emb.size(1)]
emb = self.dropout(emb)
return emb
class SpeechT5RelativePositionalEncoding(torch.nn.Module):
def __init__(self, dim, max_length=1000):
super().__init__()
self.dim = dim
self.max_length = max_length
self.pe_k = torch.nn.Embedding(2 * max_length, dim)
def forward(self, hidden_states):
seq_len = hidden_states.shape[1]
pos_seq = torch.arange(0, seq_len).to(device=hidden_states.device, dtype=torch.long)
pos_seq = pos_seq[:, None] - pos_seq[None, :]
pos_seq[pos_seq < -self.max_length] = -self.max_length
pos_seq[pos_seq >= self.max_length] = self.max_length - 1
pos_seq = pos_seq + self.max_length
return self.pe_k(pos_seq)
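# Illustrative sketch (not part of this diff): the pairwise offset matrix that indexes the
# relative-position embedding table, here for seq_len=4 and max_length=2; clamp() is used as
# a shorthand for the two explicit masking assignments in the forward above.
import torch

seq_len, max_length = 4, 2
pos_seq = torch.arange(seq_len)
pos_seq = pos_seq[:, None] - pos_seq[None, :]            # offsets i - j
pos_seq = pos_seq.clamp(-max_length, max_length - 1) + max_length
print(pos_seq)
# tensor([[2, 1, 0, 0],
#         [3, 2, 1, 0],
#         [3, 3, 2, 1],
#         [3, 3, 3, 2]])
# every entry is a valid row index into nn.Embedding(2 * max_length, dim)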
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->SpeechT5
class SpeechT5SamePadLayer(nn.Module):
def __init__(self, num_conv_pos_embeddings):
super().__init__()
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
def forward(self, hidden_states):
if self.num_pad_remove > 0:
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
return hidden_states
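# Illustrative sketch (not part of this diff): with an even kernel, "same" padding of
# kernel_size // 2 on both sides yields one extra frame, which SpeechT5SamePadLayer trims.
import torch
from torch import nn

num_conv_pos_embeddings = 4        # even kernel -> one frame too many
conv = nn.Conv1d(1, 1, kernel_size=num_conv_pos_embeddings, padding=num_conv_pos_embeddings // 2)
out = conv(torch.zeros(1, 1, 10))
print(out.shape)                   # torch.Size([1, 1, 11])
print(out[:, :, :-1].shape)        # torch.Size([1, 1, 10]) after trimming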
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->SpeechT5
class SpeechT5FeatureEncoder(nn.Module):
"""Construct the features from raw audio waveform"""
def __init__(self, config):
super().__init__()
if config.feat_extract_norm == "group":
conv_layers = [SpeechT5GroupNormConvLayer(config, layer_id=0)] + [
SpeechT5NoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
]
elif config.feat_extract_norm == "layer":
conv_layers = [
SpeechT5LayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
]
else:
raise ValueError(
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
)
self.conv_layers = nn.ModuleList(conv_layers)
self.gradient_checkpointing = False
self._requires_grad = True
def _freeze_parameters(self):
for param in self.parameters():
param.requires_grad = False
self._requires_grad = False
def forward(self, input_values):
hidden_states = input_values[:, None]
# make sure hidden_states require grad for gradient_checkpointing
if self._requires_grad and self.training:
hidden_states.requires_grad = True
for conv_layer in self.conv_layers:
hidden_states = conv_layer(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->SpeechT5
class SpeechT5FeatureProjection(nn.Module):
def __init__(self, config):
super().__init__()
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
self.dropout = nn.Dropout(config.feat_proj_dropout)
def forward(self, hidden_states):
# non-projected hidden states are needed for quantization
norm_hidden_states = self.layer_norm(hidden_states)
hidden_states = self.projection(norm_hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states, norm_hidden_states
class SpeechT5SpeechEncoderPrenet(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.feature_encoder = SpeechT5FeatureEncoder(config)
self.feature_projection = SpeechT5FeatureProjection(config)
@ -586,6 +564,38 @@ class SpeechT5SpeechEncoderPrenet(SpeechT5PreTrainedModel):
return hidden_states, attention_mask
# Copied from transformers.models.unispeech.modeling_unispeech.UniSpeechPreTrainedModel._get_feature_vector_attention_mask
def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
# Effectively attention_mask.sum(-1), but not in-place so that it can run
# in inference mode.
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
batch_size = attention_mask.shape[0]
attention_mask = torch.zeros(
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
)
# these two operations make sure that all values before the output length indices are attended to
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask
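# Illustrative sketch (not part of this diff): the flip/cumsum/flip trick turns a single 1
# placed at the last valid frame into a mask covering every frame up to and including it.
import torch

feature_vector_length = 6
output_lengths = torch.tensor([4, 2])                    # valid frames per example
mask = torch.zeros((2, feature_vector_length), dtype=torch.long)
mask[(torch.arange(2), output_lengths - 1)] = 1          # mark the last valid frame
mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
print(mask.long())
# tensor([[1, 1, 1, 1, 0, 0],
#         [1, 1, 0, 0, 0, 0]])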
# Copied from transformers.models.unispeech.modeling_unispeech.UniSpeechPreTrainedModel._get_feat_extract_output_lengths
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the convolutional layers
"""
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
return input_lengths
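# Illustrative sketch (not part of this diff): applying the Conv1d length formula
# floor((L - kernel) / stride) + 1 over a wav2vec2-style stack. The kernel/stride values
# below are assumptions for the example, not read from a real config.
import torch

conv_kernel = (10, 3, 3, 3, 3, 2, 2)
conv_stride = (5, 2, 2, 2, 2, 2, 2)
length = torch.tensor(16000)                             # 1 s of 16 kHz audio
for kernel_size, stride in zip(conv_kernel, conv_stride):
    length = torch.div(length - kernel_size, stride, rounding_mode="floor") + 1
print(length)                                            # tensor(49) frames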
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_hidden_states(
self,
@ -910,10 +920,10 @@ class SpeechT5Attention(nn.Module):
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save tuple(torch.Tensor, torch.Tensor) of
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
@ -1126,7 +1136,7 @@ class SpeechT5DecoderLayer(GradientCheckpointingLayer):
`(encoder_attention_heads,)`.
cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(decoder_attention_heads,)`.
past_key_value (`tuple(torch.FloatTensor)`): cached past key and value projection states
past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
@ -1186,6 +1196,44 @@ class SpeechT5DecoderLayer(GradientCheckpointingLayer):
return outputs
@auto_docstring
class SpeechT5PreTrainedModel(PreTrainedModel):
config_class = SpeechT5Config
base_model_prefix = "speecht5"
main_input_name = "input_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, SpeechT5PositionalConvEmbedding):
nn.init.normal_(
module.conv.weight,
mean=0,
std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
)
nn.init.constant_(module.conv.bias, 0)
elif isinstance(module, SpeechT5FeatureProjection):
k = math.sqrt(1 / module.projection.in_features)
nn.init.uniform_(module.projection.weight, a=-k, b=k)
nn.init.uniform_(module.projection.bias, a=-k, b=k)
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
nn.init.kaiming_normal_(module.weight)
if module.bias is not None:
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
nn.init.uniform_(module.bias, a=-k, b=k)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class SpeechT5Encoder(SpeechT5PreTrainedModel):
"""
Transformer encoder consisting of *config.encoder_layers* layers. Each layer is a [`SpeechT5EncoderLayer`].
@ -1490,7 +1538,7 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel):
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
@ -2040,22 +2088,6 @@ class SpeechT5Model(SpeechT5PreTrainedModel):
)
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
shifted_input_ids[:, 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
@auto_docstring(
custom_intro="""
SpeechT5 Model with a speech encoder and a text decoder.
@ -2249,27 +2281,6 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel, GenerationMixin):
return reordered_past
def shift_spectrograms_right(
input_values: torch.Tensor, reduction_factor: int = 1, attention_mask: Optional[torch.Tensor] = None
):
"""
Shift input spectrograms one timestep to the right. Also applies the reduction factor to the sequence length.
"""
# thin out frames for reduction factor
if reduction_factor > 1:
input_values = input_values[:, reduction_factor - 1 :: reduction_factor]
if attention_mask is not None:
attention_mask = attention_mask[:, reduction_factor - 1 :: reduction_factor]
shifted_input_values = input_values.new_zeros(input_values.shape)
shifted_input_values[:, 1:] = input_values[:, :-1].clone()
# replace possible -100 values in labels by zeros
shifted_input_values.masked_fill_(shifted_input_values == -100.0, 0.0)
return shifted_input_values, attention_mask
def _generate_speech(
model: SpeechT5PreTrainedModel,
input_values: torch.FloatTensor,

View File

@ -7,7 +7,7 @@
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from typing import Optional, Union
import numpy as np
import torch
@ -17,6 +17,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
@ -35,43 +36,36 @@ from .configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
@dataclass
class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
"""
@auto_docstring(
custom_intro="""
Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
Args:
loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
paper](https://arxiv.org/pdf/2006.11477.pdf).
projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
projected quantized states.
projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
target vectors for contrastive loss.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
"""
)
class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
r"""
loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
paper](https://arxiv.org/pdf/2006.11477.pdf).
projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
projected quantized states.
projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
target vectors for contrastive loss.
codevector_perplexity (`torch.FloatTensor` of shape `(1,)`):
The perplexity of the codevector distribution, used to measure the diversity of the codebook.
contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
"""
loss: Optional[torch.FloatTensor] = None
projected_states: Optional[torch.FloatTensor] = None
projected_quantized_states: Optional[torch.FloatTensor] = None
codevector_perplexity: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None
attentions: Optional[tuple[torch.FloatTensor]] = None
contrastive_loss: Optional[torch.FloatTensor] = None
diversity_loss: Optional[torch.FloatTensor] = None
@ -134,7 +128,7 @@ class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
"""Rotary positional embedding
Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://huggingface.co/papers/2104.09864
"""
def __init__(self, config):
@ -201,7 +195,7 @@ class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
# Reverse the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in https://arxiv.org/abs/1901.02860
# as in https://huggingface.co/papers/1901.02860
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
@ -216,7 +210,7 @@ class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
return relative_position_embeddings
class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
class Wav2Vec2ConformerNoLayerNormConvLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
@ -237,7 +231,7 @@ class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
return hidden_states
class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
class Wav2Vec2ConformerLayerNormConvLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
@ -264,7 +258,7 @@ class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
return hidden_states
class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
class Wav2Vec2ConformerGroupNormConvLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
@ -324,13 +318,7 @@ class Wav2Vec2ConformerFeatureEncoder(nn.Module):
hidden_states.requires_grad = True
for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
conv_layer.__call__,
hidden_states,
)
else:
hidden_states = conv_layer(hidden_states)
hidden_states = conv_layer(hidden_states)
return hidden_states
@ -457,7 +445,7 @@ class Wav2Vec2ConformerSelfAttention(nn.Module):
# linear transformation for positional encoding
self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# as described in https://huggingface.co/papers/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
@ -467,7 +455,7 @@ class Wav2Vec2ConformerSelfAttention(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
relative_position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
# self-attention mechanism
batch_size, sequence_length, hidden_size = hidden_states.size()
@ -499,7 +487,7 @@ class Wav2Vec2ConformerSelfAttention(nn.Module):
" 'relative'"
)
# apply relative_position_embeddings to qk scores
# as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860
# as proposed in Transformer_XL: https://huggingface.co/papers/1901.02860
scores = self._apply_relative_embeddings(
query=query, key=key, relative_position_embeddings=relative_position_embeddings
)
@ -559,7 +547,7 @@ class Wav2Vec2ConformerSelfAttention(nn.Module):
q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
# 3. attention score: first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# as described in https://huggingface.co/papers/1901.02860 Section 3.3
# => (batch, head, time1, time2)
scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
@ -582,8 +570,8 @@ class Wav2Vec2ConformerSelfAttention(nn.Module):
return scores
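# Illustrative sketch (not part of this diff): the Transformer-XL style score decomposition
# used above, score = (q + u) k^T + (q + v) r^T, shown on random tensors with a single head
# and omitting the index-shifting trick the real code applies to the positional term.
import math
import torch

time1 = time2 = 5
head_size = 8
q = torch.randn(1, 1, time1, head_size)     # (batch, head, time1, head_size)
k = torch.randn(1, 1, time2, head_size)
r = torch.randn(1, 1, time2, head_size)     # projected relative position embeddings
u = torch.zeros(1, head_size)               # pos_bias_u, zero-initialised as in the module above
v = torch.zeros(1, head_size)               # pos_bias_v
scores_ac = torch.matmul(q + u, k.transpose(-2, -1))
scores_bd = torch.matmul(q + v, r.transpose(-2, -1))
scores = (scores_ac + scores_bd) / math.sqrt(head_size)
print(scores.shape)                         # torch.Size([1, 1, 5, 5])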
class Wav2Vec2ConformerEncoderLayer(nn.Module):
"""Conformer block based on https://arxiv.org/abs/2005.08100."""
class Wav2Vec2ConformerEncoderLayer(GradientCheckpointingLayer):
"""Conformer block based on https://huggingface.co/papers/2005.08100."""
def __init__(self, config):
super().__init__()
@ -703,27 +691,18 @@ class Wav2Vec2ConformerEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
relative_position_embeddings,
output_attentions,
)
else:
layer_outputs = layer(
hidden_states,
attention_mask=attention_mask,
relative_position_embeddings=relative_position_embeddings,
output_attentions=output_attentions,
)
layer_outputs = layer(
hidden_states,
attention_mask=attention_mask,
relative_position_embeddings=relative_position_embeddings,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if skip_the_layer:
@ -748,7 +727,7 @@ class Wav2Vec2ConformerEncoder(nn.Module):
class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
"""
Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
GUMBEL-SOFTMAX](https://huggingface.co/papers/1611.01144) for more information.
"""
def __init__(self, config):
@ -966,7 +945,7 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
def _compute_mask_indices(
shape: Tuple[int, int],
shape: tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
@ -974,7 +953,7 @@ def _compute_mask_indices(
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
@ -1121,7 +1100,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
):
"""
Masks extracted features along time axis and/or along feature axis according to
[SpecAugment](https://arxiv.org/abs/1904.08779).
[SpecAugment](https://huggingface.co/papers/1904.08779).
"""
# `config.apply_spec_augment` can set masking to False
@ -1168,7 +1147,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Wav2Vec2ConformerBaseModelOutput]:
) -> Union[tuple, Wav2Vec2ConformerBaseModelOutput]:
r"""
mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
@ -1281,7 +1260,7 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
) -> Union[tuple, Wav2Vec2ConformerForPreTrainingOutput]:
r"""
mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
@ -1386,7 +1365,7 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
).permute(2, 0, 1, 3)
# 4. compute logits, corresponding to `logits = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
# of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
# of equation (3) in https://huggingface.co/papers/2006.11477
logits = self.compute_contrastive_logits(
quantized_features[None, :],
negative_quantized_features,
@ -1485,7 +1464,7 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
) -> Union[Tuple, CausalLMOutput]:
) -> Union[tuple, CausalLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
@ -1596,11 +1575,11 @@ class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedMode
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
) -> Union[Tuple, SequenceClassifierOutput]:
) -> Union[tuple, SequenceClassifierOutput]:
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2ConformerProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -1698,11 +1677,11 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedMo
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
) -> Union[tuple, TokenClassifierOutput]:
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2ConformerProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -1751,7 +1730,7 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedMo
class AMSoftmaxLoss(nn.Module):
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
super(AMSoftmaxLoss, self).__init__()
super().__init__()
self.scale = scale
self.margin = margin
self.num_labels = num_labels
@ -1868,11 +1847,11 @@ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
) -> Union[Tuple, XVectorOutput]:
) -> Union[tuple, XVectorOutput]:
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2ConformerProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):