mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Speech] Refactor Examples (#14040)
* adapt_examples * up * up * up * up * add auto models * finish
This commit is contained in:
parent
2024faf171
commit
d5ff69fce9
@ -59,3 +59,9 @@ SEWForCTC
|
||||
.. autoclass:: transformers.SEWForCTC
|
||||
:members: forward
|
||||
|
||||
|
||||
SEWForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.SEWForSequenceClassification
|
||||
:members: forward
|
||||
|
@ -59,3 +59,8 @@ SEWDForCTC
|
||||
.. autoclass:: transformers.SEWDForCTC
|
||||
:members: forward
|
||||
|
||||
SEWDForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.SEWDForSequenceClassification
|
||||
:members: forward
|
||||
|
@ -1143,6 +1143,7 @@ if is_torch_available():
|
||||
[
|
||||
"SEW_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"SEWForCTC",
|
||||
"SEWForSequenceClassification",
|
||||
"SEWModel",
|
||||
"SEWPreTrainedModel",
|
||||
]
|
||||
@ -1151,6 +1152,7 @@ if is_torch_available():
|
||||
[
|
||||
"SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"SEWDForCTC",
|
||||
"SEWDForSequenceClassification",
|
||||
"SEWDModel",
|
||||
"SEWDPreTrainedModel",
|
||||
]
|
||||
@ -2858,8 +2860,20 @@ if TYPE_CHECKING:
|
||||
RoFormerPreTrainedModel,
|
||||
load_tf_weights_in_roformer,
|
||||
)
|
||||
from .models.sew import SEW_PRETRAINED_MODEL_ARCHIVE_LIST, SEWForCTC, SEWModel, SEWPreTrainedModel
|
||||
from .models.sew_d import SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST, SEWDForCTC, SEWDModel, SEWDPreTrainedModel
|
||||
from .models.sew import (
|
||||
SEW_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
SEWForCTC,
|
||||
SEWForSequenceClassification,
|
||||
SEWModel,
|
||||
SEWPreTrainedModel,
|
||||
)
|
||||
from .models.sew_d import (
|
||||
SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
SEWDForCTC,
|
||||
SEWDForSequenceClassification,
|
||||
SEWDModel,
|
||||
SEWDPreTrainedModel,
|
||||
)
|
||||
from .models.speech_encoder_decoder import SpeechEncoderDecoderModel
|
||||
from .models.speech_to_text import (
|
||||
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
|
@ -476,6 +476,8 @@ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
# Model for Audio Classification mapping
|
||||
("wav2vec2", "Wav2Vec2ForSequenceClassification"),
|
||||
("hubert", "HubertForSequenceClassification"),
|
||||
("sew", "SEWForSequenceClassification"),
|
||||
("sew-d", "SEWDForSequenceClassification"),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -25,7 +25,12 @@ from torch.nn import CrossEntropyLoss
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from ...file_utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
@ -36,6 +41,13 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "HubertConfig"
|
||||
_CHECKPOINT_FOR_DOC = "facebook/hubert-base-ls960"
|
||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||
|
||||
_SEQ_CLASS_CHECKPOINT = ("superb/hubert-base-superb-ks",)
|
||||
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||
|
||||
_HIDDEN_STATES_START_POSITION = 1
|
||||
|
||||
|
||||
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"facebook/hubert-base-ls960",
|
||||
@ -999,6 +1011,7 @@ class HubertModel(HubertPreTrainedModel):
|
||||
"""Hubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """,
|
||||
HUBERT_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT
|
||||
class HubertForCTC(HubertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -1025,7 +1038,12 @@ class HubertForCTC(HubertPreTrainedModel):
|
||||
self.hubert.feature_extractor._freeze_parameters()
|
||||
|
||||
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_PROCESSOR_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=CausalLMOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
@ -1041,41 +1059,6 @@ class HubertForCTC(HubertPreTrainedModel):
|
||||
the sequence length of the output logits. Indices are selected in ``[-100, 0, ..., config.vocab_size -
|
||||
1]``. All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ...,
|
||||
config.vocab_size - 1]``.
|
||||
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> import torch
|
||||
>>> from transformers import Wav2Vec2Processor, HubertForCTC
|
||||
>>> from datasets import load_dataset
|
||||
>>> import soundfile as sf
|
||||
|
||||
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
|
||||
>>> model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")
|
||||
|
||||
>>> def map_to_array(batch):
|
||||
... speech, _ = sf.read(batch["file"])
|
||||
... batch["speech"] = speech
|
||||
... return batch
|
||||
|
||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> ds = ds.map(map_to_array)
|
||||
|
||||
>>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
|
||||
>>> logits = model(input_values).logits
|
||||
>>> predicted_ids = torch.argmax(logits, dim=-1)
|
||||
|
||||
>>> transcription = processor.decode(predicted_ids[0])
|
||||
|
||||
>>> # compute loss
|
||||
>>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
|
||||
|
||||
>>> # wrap processor as target processor to encode labels
|
||||
>>> with processor.as_target_processor():
|
||||
... labels = processor(target_transcription, return_tensors="pt").input_ids
|
||||
|
||||
>>> loss = model(input_values, labels=labels).loss
|
||||
"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
@ -1126,7 +1109,7 @@ class HubertForCTC(HubertPreTrainedModel):
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return CausalLMOutput(
|
||||
@ -1141,8 +1124,8 @@ class HubertForCTC(HubertPreTrainedModel):
|
||||
""",
|
||||
HUBERT_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT
|
||||
class HubertForSequenceClassification(HubertPreTrainedModel):
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Hubert, wav2vec2->hubert
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1155,7 +1138,6 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor with wav2vec2->hubert
|
||||
def freeze_feature_extractor(self):
|
||||
"""
|
||||
Calling this function will disable the gradient computation for the feature extractor so that its parameters
|
||||
@ -1163,7 +1145,6 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
|
||||
"""
|
||||
self.hubert.feature_extractor._freeze_parameters()
|
||||
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->hubert
|
||||
def freeze_base_model(self):
|
||||
"""
|
||||
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
||||
@ -1173,7 +1154,13 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
|
||||
param.requires_grad = False
|
||||
|
||||
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
|
||||
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
||||
output_type=SequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
modality="audio",
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
@ -1188,29 +1175,6 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
|
||||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> import torch
|
||||
>>> from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification
|
||||
>>> from datasets import load_dataset
|
||||
|
||||
>>> processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")
|
||||
>>> model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks")
|
||||
|
||||
>>> ds = load_dataset("anton-l/superb_dummy", "ks", split="test")
|
||||
|
||||
>>> input_values = processor(ds["speech"][4], return_tensors="pt").input_values # Batch size 1
|
||||
>>> logits = model(input_values).logits
|
||||
>>> predicted_class_ids = torch.argmax(logits, dim=-1)
|
||||
|
||||
>>> # compute loss
|
||||
>>> target_label = "down"
|
||||
>>> labels = torch.tensor([model.config.label2id[target_label]])
|
||||
|
||||
>>> loss = model(input_values, labels=labels).loss
|
||||
"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
@ -1225,7 +1189,7 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
|
||||
)
|
||||
|
||||
if self.config.use_weighted_layer_sum:
|
||||
hidden_states = outputs[1]
|
||||
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
||||
hidden_states = torch.stack(hidden_states, dim=1)
|
||||
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
||||
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
||||
@ -1248,7 +1212,7 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
|
||||
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
|
@ -28,6 +28,7 @@ if is_torch_available():
|
||||
_import_structure["modeling_sew"] = [
|
||||
"SEW_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"SEWForCTC",
|
||||
"SEWForSequenceClassification",
|
||||
"SEWModel",
|
||||
"SEWPreTrainedModel",
|
||||
]
|
||||
@ -36,7 +37,13 @@ if TYPE_CHECKING:
|
||||
from .configuration_sew import SEW_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWConfig
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_sew import SEW_PRETRAINED_MODEL_ARCHIVE_LIST, SEWForCTC, SEWModel, SEWPreTrainedModel
|
||||
from .modeling_sew import (
|
||||
SEW_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
SEWForCTC,
|
||||
SEWForSequenceClassification,
|
||||
SEWModel,
|
||||
SEWPreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
|
@ -113,6 +113,11 @@ class SEWConfig(PretrainedConfig):
|
||||
Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
|
||||
mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
|
||||
instance of :class:`~transformers.SEWForCTC`.
|
||||
use_weighted_layer_sum (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
|
||||
instance of :class:`~transformers.SEWForSequenceClassification`.
|
||||
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
|
||||
Dimensionality of the projection before token mean-pooling for classification.
|
||||
|
||||
Example::
|
||||
|
||||
@ -161,6 +166,8 @@ class SEWConfig(PretrainedConfig):
|
||||
mask_feature_length=10,
|
||||
ctc_loss_reduction="sum",
|
||||
ctc_zero_infinity=False,
|
||||
use_weighted_layer_sum=False,
|
||||
classifier_proj_size=256,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
@ -214,3 +221,7 @@ class SEWConfig(PretrainedConfig):
|
||||
# ctc loss
|
||||
self.ctc_loss_reduction = ctc_loss_reduction
|
||||
self.ctc_zero_infinity = ctc_zero_infinity
|
||||
|
||||
# sequence classification
|
||||
self.use_weighted_layer_sum = use_weighted_layer_sum
|
||||
self.classifier_proj_size = classifier_proj_size
|
||||
|
@ -21,12 +21,18 @@ import numpy as np
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput
|
||||
from ...file_utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
from .configuration_sew import SEWConfig
|
||||
@ -36,6 +42,13 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "SEWConfig"
|
||||
_CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k"
|
||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||
|
||||
_SEQ_CLASS_CHECKPOINT = "asapp/sew-tiny-100k"
|
||||
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||
|
||||
_HIDDEN_STATES_START_POSITION = 1
|
||||
|
||||
|
||||
SEW_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"asapp/sew-tiny-100k",
|
||||
@ -900,6 +913,7 @@ class SEWModel(SEWPreTrainedModel):
|
||||
"""SEW Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """,
|
||||
SEW_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV_2_VEC_2->SEW
|
||||
class SEWForCTC(SEWPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -926,7 +940,12 @@ class SEWForCTC(SEWPreTrainedModel):
|
||||
self.sew.feature_extractor._freeze_parameters()
|
||||
|
||||
@add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_PROCESSOR_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=CausalLMOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
@ -942,41 +961,6 @@ class SEWForCTC(SEWPreTrainedModel):
|
||||
the sequence length of the output logits. Indices are selected in ``[-100, 0, ..., config.vocab_size -
|
||||
1]``. All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ...,
|
||||
config.vocab_size - 1]``.
|
||||
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> import torch
|
||||
>>> from transformers import Wav2Vec2Processor, SEWForCTC
|
||||
>>> from datasets import load_dataset
|
||||
>>> import soundfile as sf
|
||||
|
||||
>>> processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k")
|
||||
>>> model = SEWForCTC.from_pretrained("asapp/sew-tiny-100k")
|
||||
|
||||
>>> def map_to_array(batch):
|
||||
... speech, _ = sf.read(batch["file"])
|
||||
... batch["speech"] = speech
|
||||
... return batch
|
||||
|
||||
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> ds = ds.map(map_to_array)
|
||||
|
||||
>>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
|
||||
>>> logits = model(input_values).logits
|
||||
>>> predicted_ids = torch.argmax(logits, dim=-1)
|
||||
|
||||
>>> transcription = processor.decode(predicted_ids[0])
|
||||
|
||||
>>> # compute loss
|
||||
>>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
|
||||
|
||||
>>> # wrap processor as target processor to encode labels
|
||||
>>> with processor.as_target_processor():
|
||||
... labels = processor(target_transcription, return_tensors="pt").input_ids
|
||||
|
||||
>>> loss = model(input_values, labels=labels).loss
|
||||
"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
@ -1027,9 +1011,115 @@ class SEWForCTC(SEWPreTrainedModel):
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return CausalLMOutput(
|
||||
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
    """
    SEW Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB
    Keyword Spotting.
    """,
    SEW_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEW, wav2vec2->sew, WAV_2_VEC_2->SEW
# NOTE(review): the ``labels`` description in ``forward`` was corrected to match the implemented loss and therefore
# no longer mirrors the Wav2Vec2 text word for word -- apply the same correction to the original if the copy marker
# above is verified automatically (TODO confirm).
class SEWForSequenceClassification(SEWPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Backbone whose outputs are pooled over the sequence axis and then classified.
        self.sew = SEWModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            # One learnable weight per layer output, initialised uniformly to 1 / num_layers.
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        # Head: project every position to ``classifier_proj_size``; after pooling, map to ``num_labels`` logits.
        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

        self.init_weights()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature extractor so that its parameters
        will not be updated during training.
        """
        self.sew.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.sew.parameters():
            param.requires_grad = False

    @add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
        checkpoint=_SEQ_CLASS_CHECKPOINT,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. When provided, a classification loss (Cross-Entropy) is computed; no regression
            loss is implemented, regardless of :obj:`config.num_labels`.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The weighted layer sum consumes every layer's output, so hidden states are always requested in that mode.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.sew(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # Stack the per-layer outputs along a new axis (dim=1) and mix them with softmax-normalised weights.
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            # Otherwise use only the first element of the backbone output.
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)
        if attention_mask is None:
            # No mask supplied: plain mean over the sequence axis.
            pooled_output = hidden_states.mean(dim=1)
        else:
            # Masked mean: positions where the mask is False are zeroed and left out of the divisor.
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            hidden_states[~padding_mask] = 0.0
            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            # Tuple output: logits followed by the backbone's extra outputs, with the loss prepended when computed.
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||
|
@ -28,6 +28,7 @@ if is_torch_available():
|
||||
_import_structure["modeling_sew_d"] = [
|
||||
"SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"SEWDForCTC",
|
||||
"SEWDForSequenceClassification",
|
||||
"SEWDModel",
|
||||
"SEWDPreTrainedModel",
|
||||
]
|
||||
@ -36,7 +37,13 @@ if TYPE_CHECKING:
|
||||
from .configuration_sew_d import SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWDConfig
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_sew_d import SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST, SEWDForCTC, SEWDModel, SEWDPreTrainedModel
|
||||
from .modeling_sew_d import (
|
||||
SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
SEWDForCTC,
|
||||
SEWDForSequenceClassification,
|
||||
SEWDModel,
|
||||
SEWDPreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
|
@ -131,6 +131,11 @@ class SEWDConfig(PretrainedConfig):
|
||||
Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
|
||||
mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
|
||||
instance of :class:`~transformers.SEWDForCTC`.
|
||||
use_weighted_layer_sum (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
|
||||
instance of :class:`~transformers.SEWDForSequenceClassification`.
|
||||
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
|
||||
Dimensionality of the projection before token mean-pooling for classification.
|
||||
|
||||
Example::
|
||||
|
||||
@ -186,6 +191,8 @@ class SEWDConfig(PretrainedConfig):
|
||||
mask_feature_length=10,
|
||||
ctc_loss_reduction="sum",
|
||||
ctc_zero_infinity=False,
|
||||
use_weighted_layer_sum=False,
|
||||
classifier_proj_size=256,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
@ -246,3 +253,7 @@ class SEWDConfig(PretrainedConfig):
|
||||
# ctc loss
|
||||
self.ctc_loss_reduction = ctc_loss_reduction
|
||||
self.ctc_zero_infinity = ctc_zero_infinity
|
||||
|
||||
# sequence classification
|
||||
self.use_weighted_layer_sum = use_weighted_layer_sum
|
||||
self.classifier_proj_size = classifier_proj_size
|
||||
|
@ -22,13 +22,18 @@ import numpy as np
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import _softmax_backward_data, nn
|
||||
from torch.nn import LayerNorm
|
||||
from torch.nn import CrossEntropyLoss, LayerNorm
|
||||
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput
|
||||
from ...file_utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
from .configuration_sew_d import SEWDConfig
|
||||
@ -38,6 +43,12 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "SEWDConfig"
|
||||
_CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k"
|
||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||
|
||||
_SEQ_CLASS_CHECKPOINT = "asapp/sew-d-tiny-100k"
|
||||
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||
|
||||
_HIDDEN_STATES_START_POSITION = 1
|
||||
|
||||
SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"asapp/sew-d-tiny-100k",
|
||||
@ -1405,6 +1416,7 @@ class SEWDModel(SEWDPreTrainedModel):
|
||||
"""SEW-D Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """,
|
||||
SEWD_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV_2_VEC_2->SEWD
|
||||
class SEWDForCTC(SEWDPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -1431,7 +1443,12 @@ class SEWDForCTC(SEWDPreTrainedModel):
|
||||
self.sew_d.feature_extractor._freeze_parameters()
|
||||
|
||||
@add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_PROCESSOR_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=CausalLMOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
@ -1447,41 +1464,6 @@ class SEWDForCTC(SEWDPreTrainedModel):
|
||||
the sequence length of the output logits. Indices are selected in ``[-100, 0, ..., config.vocab_size -
|
||||
1]``. All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ...,
|
||||
config.vocab_size - 1]``.
|
||||
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> import torch
|
||||
>>> from transformers import Wav2Vec2Processor, SEWDForCTC
|
||||
>>> from datasets import load_dataset
|
||||
>>> import soundfile as sf
|
||||
|
||||
>>> processor = Wav2Vec2Processor.from_pretrained("asapp/sew-d-tiny-100k")
|
||||
>>> model = SEWDForCTC.from_pretrained("asapp/sew-d-tiny-100k")
|
||||
|
||||
>>> def map_to_array(batch):
|
||||
... speech, _ = sf.read(batch["file"])
|
||||
... batch["speech"] = speech
|
||||
... return batch
|
||||
|
||||
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> ds = ds.map(map_to_array)
|
||||
|
||||
>>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
|
||||
>>> logits = model(input_values).logits
|
||||
>>> predicted_ids = torch.argmax(logits, dim=-1)
|
||||
|
||||
>>> transcription = processor.decode(predicted_ids[0])
|
||||
|
||||
>>> # compute loss
|
||||
>>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
|
||||
|
||||
>>> # wrap processor as target processor to encode labels
|
||||
>>> with processor.as_target_processor():
|
||||
... labels = processor(target_transcription, return_tensors="pt").input_ids
|
||||
|
||||
>>> loss = model(input_values, labels=labels).loss
|
||||
"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
@ -1532,9 +1514,115 @@ class SEWDForCTC(SEWDPreTrainedModel):
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return CausalLMOutput(
|
||||
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
    """
    SEWD Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB
    Keyword Spotting.
    """,
    SEWD_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV_2_VEC_2->SEWD
class SEWDForSequenceClassification(SEWDPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.sew_d = SEWDModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            # one learnable scalar per layer, initialised to a uniform average
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

        self.init_weights()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature extractor so that its parameters
        will not be updated during training.
        """
        self.sew_d.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.sew_d.parameters():
            param.requires_grad = False

    @add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
        checkpoint=_SEQ_CLASS_CHECKPOINT,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. When provided, a classification loss (Cross-Entropy) is computed over the
            logits.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # the weighted layer sum below consumes every layer's hidden states, so they are always requested then
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.sew_d(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # stack the per-layer hidden states along a new dim 1 and mix them with softmax-normalised weights
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)
        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            # zero the positions excluded by the mask, then average over the kept positions only
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            hidden_states[~padding_mask] = 0.0
            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # the loss is always cross-entropy, regardless of config.num_labels
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||
|
@ -46,6 +46,12 @@ _CONFIG_FOR_DOC = "Wav2Vec2Config"
|
||||
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"

# Checkpoint and processor class passed to the sequence-classification code-sample docstring.
# The checkpoint must be a plain string: a trailing comma here would silently make it a 1-tuple.
_SEQ_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-ks"
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"

# Index in the base model's output tuple at which the hidden states start.
_HIDDEN_STATES_START_POSITION = 2
|
||||
|
||||
|
||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"facebook/wav2vec2-base-960h",
|
||||
"facebook/wav2vec2-large-960h",
|
||||
@ -1557,7 +1563,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return CausalLMOutput(
|
||||
@ -1602,8 +1608,8 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
|
||||
|
||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class="Wav2Vec2FeatureExtractor",
|
||||
checkpoint="superb/wav2vec2-base-superb-ks",
|
||||
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
|
||||
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
||||
output_type=SequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
modality="audio",
|
||||
@ -1636,7 +1642,7 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
|
||||
)
|
||||
|
||||
if self.config.use_weighted_layer_sum:
|
||||
hidden_states = outputs[2]
|
||||
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
||||
hidden_states = torch.stack(hidden_states, dim=1)
|
||||
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
||||
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
||||
@ -1659,7 +1665,7 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
|
||||
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
|
@ -3289,6 +3289,15 @@ class SEWForCTC:
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class SEWForSequenceClassification:
    """
    Stand-in object for ``SEWForSequenceClassification``.

    Both the constructor and :meth:`from_pretrained` accept any arguments and only hand control to
    ``requires_backends`` with the ``torch`` backend (presumably reporting that torch is missing -- confirm against
    ``requires_backends``).
    """

    def __init__(self, *args, **kwargs):
        needed_backends = ["torch"]
        requires_backends(self, needed_backends)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        needed_backends = ["torch"]
        requires_backends(cls, needed_backends)
|
||||
|
||||
|
||||
class SEWModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
@ -3315,6 +3324,15 @@ class SEWDForCTC:
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class SEWDForSequenceClassification:
    """
    Stand-in object for ``SEWDForSequenceClassification``.

    Both the constructor and :meth:`from_pretrained` accept any arguments and only hand control to
    ``requires_backends`` with the ``torch`` backend (presumably reporting that torch is missing -- confirm against
    ``requires_backends``).
    """

    def __init__(self, *args, **kwargs):
        needed_backends = ["torch"]
        requires_backends(self, needed_backends)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        needed_backends = ["torch"]
        requires_backends(cls, needed_backends)
|
||||
|
||||
|
||||
class SEWDModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
@ -31,7 +31,13 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import SEWForCTC, SEWModel, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
|
||||
from transformers import (
|
||||
SEWForCTC,
|
||||
SEWForSequenceClassification,
|
||||
SEWModel,
|
||||
Wav2Vec2FeatureExtractor,
|
||||
Wav2Vec2Processor,
|
||||
)
|
||||
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
|
||||
|
||||
|
||||
@ -219,6 +225,54 @@ class SEWModelTester:
|
||||
|
||||
loss.backward()
|
||||
|
||||
def check_seq_classifier_loss(self, config, input_values, *args):
    """
    Check that supplying an attention mask changes the sequence-classification loss on padded input.

    The first three rows of ``input_values`` are zero-padded beyond 1/4, 1/2 and the full length respectively, and
    the loss is computed once with the matching attention mask and once without it. Both losses must be Python
    floats and must differ.

    Args:
        config: model configuration used to build ``SEWForSequenceClassification``.
        input_values: batch of inputs; its first three rows are modified in place.
        *args: ignored extra entries of the prepared inputs.
    """
    model = SEWForSequenceClassification(config=config)
    model.to(torch_device)

    # make sure that dropout is disabled
    model.eval()

    input_values = input_values[:3]
    attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)

    input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
    labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))

    # pad input
    for i, length in enumerate(input_lengths):
        input_values[i, length:] = 0.0
        attention_mask[i, length:] = 0

    masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
    unmasked_loss = model(input_values, labels=labels).loss.item()

    # dedicated assertion helpers report the offending values on failure
    self.parent.assertIsInstance(masked_loss, float)
    self.parent.assertIsInstance(unmasked_loss, float)
    self.parent.assertNotEqual(masked_loss, unmasked_loss)
|
||||
|
||||
def check_seq_classifier_training(self, config, input_values, *args):
    """
    Run one forward/backward pass of ``SEWForSequenceClassification`` in training mode with a frozen base model.

    The first three input rows are zero-padded beyond 1/4, 1/2 and the full length; the resulting loss must not be
    infinite and must support ``backward()``.
    """
    config.ctc_zero_infinity = True
    classifier = SEWForSequenceClassification(config=config)
    classifier.to(torch_device)
    classifier.train()

    # freeze everything but the classification head
    classifier.freeze_base_model()

    batch = input_values[:3]

    full_length = batch.shape[-1]
    keep_lengths = [full_length // divisor for divisor in (4, 2, 1)]
    labels = ids_tensor((batch.shape[0], 1), len(classifier.config.id2label))

    # pad input
    for row, keep in enumerate(keep_lengths):
        batch[row, keep:] = 0.0

    loss = classifier(batch, labels=labels).loss
    loss_is_infinite = torch.isinf(loss).item()
    self.parent.assertFalse(loss_is_infinite)

    loss.backward()
|
||||
|
||||
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||
model = SEWForCTC(config)
|
||||
model.to(torch_device)
|
||||
@ -241,7 +295,7 @@ class SEWModelTester:
|
||||
|
||||
@require_torch
|
||||
class SEWModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (SEWForCTC, SEWModel) if is_torch_available() else ()
|
||||
all_model_classes = (SEWForCTC, SEWModel, SEWForSequenceClassification) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_torchscript = False
|
||||
@ -328,6 +382,14 @@ class SEWModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.assertIsNotNone(hidden_states.grad)
|
||||
self.assertIsNotNone(attentions.grad)
|
||||
|
||||
def test_seq_classifier_loss_inference(self):
    """Feed freshly prepared config and inputs to the tester's sequence-classifier loss check."""
    prepared = self.model_tester.prepare_config_and_inputs()
    run_check = self.model_tester.check_seq_classifier_loss
    run_check(*prepared)
|
||||
|
||||
def test_seq_classifier_train(self):
    """Feed freshly prepared config and inputs to the tester's sequence-classifier training check."""
    prepared = self.model_tester.prepare_config_and_inputs()
    run_check = self.model_tester.check_seq_classifier_training
    run_check(*prepared)
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
@ -31,7 +31,13 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import SEWDForCTC, SEWDModel, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
|
||||
from transformers import (
|
||||
SEWDForCTC,
|
||||
SEWDForSequenceClassification,
|
||||
SEWDModel,
|
||||
Wav2Vec2FeatureExtractor,
|
||||
Wav2Vec2Processor,
|
||||
)
|
||||
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
|
||||
|
||||
|
||||
@ -240,6 +246,54 @@ class SEWDModelTester:
|
||||
|
||||
loss.backward()
|
||||
|
||||
def check_seq_classifier_loss(self, config, input_values, *args):
    """
    Check that supplying an attention mask changes the sequence-classification loss on padded input.

    The first three rows of ``input_values`` are zero-padded beyond 1/4, 1/2 and the full length respectively, and
    the loss is computed once with the matching attention mask and once without it. Both losses must be Python
    floats and must differ.

    Args:
        config: model configuration used to build ``SEWDForSequenceClassification``.
        input_values: batch of inputs; its first three rows are modified in place.
        *args: ignored extra entries of the prepared inputs.
    """
    model = SEWDForSequenceClassification(config=config)
    model.to(torch_device)

    # make sure that dropout is disabled
    model.eval()

    input_values = input_values[:3]
    attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)

    input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
    labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))

    # pad input
    for i, length in enumerate(input_lengths):
        input_values[i, length:] = 0.0
        attention_mask[i, length:] = 0

    masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
    unmasked_loss = model(input_values, labels=labels).loss.item()

    # dedicated assertion helpers report the offending values on failure
    self.parent.assertIsInstance(masked_loss, float)
    self.parent.assertIsInstance(unmasked_loss, float)
    self.parent.assertNotEqual(masked_loss, unmasked_loss)
|
||||
|
||||
def check_seq_classifier_training(self, config, input_values, *args):
    """
    Run one forward/backward pass of ``SEWDForSequenceClassification`` in training mode with a frozen base model.

    The first three input rows are zero-padded beyond 1/4, 1/2 and the full length; the resulting loss must not be
    infinite and must support ``backward()``.
    """
    config.ctc_zero_infinity = True
    classifier = SEWDForSequenceClassification(config=config)
    classifier.to(torch_device)
    classifier.train()

    # freeze everything but the classification head
    classifier.freeze_base_model()

    batch = input_values[:3]

    full_length = batch.shape[-1]
    keep_lengths = [full_length // divisor for divisor in (4, 2, 1)]
    labels = ids_tensor((batch.shape[0], 1), len(classifier.config.id2label))

    # pad input
    for row, keep in enumerate(keep_lengths):
        batch[row, keep:] = 0.0

    loss = classifier(batch, labels=labels).loss
    loss_is_infinite = torch.isinf(loss).item()
    self.parent.assertFalse(loss_is_infinite)

    loss.backward()
|
||||
|
||||
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||
model = SEWDForCTC(config)
|
||||
model.to(torch_device)
|
||||
@ -262,7 +316,7 @@ class SEWDModelTester:
|
||||
|
||||
@require_torch
|
||||
class SEWDModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (SEWDForCTC, SEWDModel) if is_torch_available() else ()
|
||||
all_model_classes = (SEWDForCTC, SEWDModel, SEWDForSequenceClassification) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_torchscript = False
|
||||
|
Loading…
Reference in New Issue
Block a user