mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
fixed calculation of ctc loss in TFWav2Vec2ForCTC (#18014)
Co-authored-by: Sreyan-G@NVIDIA <sreyang@nvidia.com>
This commit is contained in:
parent
96d833b211
commit
e3139ad301
@ -25,7 +25,13 @@ import tensorflow as tf
|
||||
|
||||
from ...activations_tf import get_tf_activation
|
||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
|
||||
from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
|
||||
from ...modeling_tf_utils import (
|
||||
TFPreTrainedModel,
|
||||
booleans_processing,
|
||||
get_initializer,
|
||||
keras_serializable,
|
||||
unpack_inputs,
|
||||
)
|
||||
from ...tf_utils import shape_list, stable_softmax
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
@ -1580,6 +1586,7 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
|
||||
"""
|
||||
self.wav2vec2.feature_extractor.trainable = False
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
@ -1702,6 +1709,8 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
|
||||
loss = tf.reduce_sum(loss)
|
||||
if self.config.ctc_loss_reduction == "mean":
|
||||
loss = tf.reduce_mean(loss)
|
||||
|
||||
loss = tf.reshape(loss, (1,))
|
||||
else:
|
||||
loss = None
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user