fixed calculation of ctc loss in TFWav2Vec2ForCTC (#18014)

Co-authored-by: Sreyan-G@NVIDIA <sreyang@nvidia.com>
This commit is contained in:
Sreyan Ghosh 2022-07-04 22:06:36 +05:30 committed by GitHub
parent 96d833b211
commit e3139ad301
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -25,7 +25,13 @@ import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
from ...modeling_tf_utils import (
TFPreTrainedModel,
booleans_processing,
get_initializer,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
ModelOutput,
@ -1580,6 +1586,7 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
"""
self.wav2vec2.feature_extractor.trainable = False
@unpack_inputs
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
@ -1702,6 +1709,8 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
loss = tf.reduce_sum(loss)
if self.config.ctc_loss_reduction == "mean":
loss = tf.reduce_mean(loss)
loss = tf.reshape(loss, (1,))
else:
loss = None