Replace input_values_processing with unpack_inputs (#21502)

* Replace input_values_processing with unpack_inputs

* Skip test failing with OOM

* Update tests
amyeroberts 2023-02-10 18:19:39 +00:00 committed by GitHub
parent 557125637d
commit cb56590111
4 changed files with 96 additions and 427 deletions
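
The gist of the change: `input_values_processing` rebuilt each model's argument dict by hand (signature inspection, symbolic-tensor name matching, boolean defaulting via `booleans_processing`), and every `call` body then read values back out of that dict. The `@unpack_inputs` decorator performs that normalization before `call` runs, so the bodies below use their parameters directly. A minimal sketch of the pattern, assuming a model whose `self.config` carries the usual boolean attributes; this is an illustration, not the real decorator in modeling_tf_utils.py:

import functools
import inspect


def unpack_inputs_sketch(call):
    """Simplified stand-in for transformers' @unpack_inputs (illustration only)."""

    @functools.wraps(call)
    def wrapper(self, *args, **kwargs):
        # Bind every positional/keyword argument against `call`'s signature,
        # so the body can rely on named parameters.
        bound = inspect.signature(call).bind(self, *args, **kwargs)
        bound.apply_defaults()
        unpacked = dict(bound.arguments)
        unpacked.pop("self")
        extra = unpacked.pop("kwargs", {})  # a trailing `**kwargs: Any`, if present
        # Boolean options left as None fall back to the model config.
        for name in ("output_attentions", "output_hidden_states", "return_dict"):
            if name in unpacked and unpacked[name] is None:
                unpacked[name] = getattr(self.config, name)
        return call(self, **unpacked, **extra)

    return wrapper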

src/transformers/models/hubert/modeling_tf_hubert.py

@@ -13,9 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ TensorFlow Hubert model."""
-import inspect
 import warnings
-from collections.abc import Mapping
 from typing import Any, Dict, Optional, Tuple, Union
 
 import numpy as np
@@ -23,10 +21,14 @@ import tensorflow as tf
 
 from ...activations_tf import get_tf_activation
 from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
-from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
+from ...modeling_tf_utils import (
+    TFPreTrainedModel,
+    get_initializer,
+    keras_serializable,
+    unpack_inputs,
+)
 from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
-    ModelOutput,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     logging,
@@ -47,124 +49,6 @@ TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
 LARGE_NEGATIVE = -1e8
 
 
-# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.input_values_processing
-def input_values_processing(func, config, input_values, **kwargs):
-    """
-    Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
-    has to be named accordingly to the parameters name, i.e. `input_values = tf.keras.Input(shape=(128,),
-    dtype='float32', name="input_values")` otherwise the order of the tensors will not be guaranteed during the
-    training.
-
-    Args:
-        func (`callable`):
-            The callable function of the TensorFlow model.
-        config ([`PretrainedConfig`]):
-            The config of the running model.
-        **kwargs:
-            The inputs of the model.
-
-    Returns:
-        Two lists, one for the missing layers, and another one for the unexpected layers.
-    """
-    signature = dict(inspect.signature(func).parameters)
-    signature.pop("kwargs", None)
-    signature.pop("self", None)
-    parameter_names = list(signature.keys())
-    output = {}
-    allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
-
-    for k, v in kwargs.items():
-        if isinstance(v, allowed_types) or v is None:
-            output[k] = v
-        else:
-            raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
-
-    if isinstance(input_values, (tuple, list)):
-        for i, input in enumerate(input_values):
-            # EagerTensors don't allow to use the .name property so we check for a real Tensor
-            if type(input) == tf.Tensor:
-                # Tensor names have always the pattern `name:id` then we check only the
-                # `name` part
-                tensor_name = input.name.split(":")[0]
-
-                if tensor_name in parameter_names:
-                    output[tensor_name] = input
-                else:
-                    output[parameter_names[i]] = input
-            elif isinstance(input, allowed_types) or input is None:
-                output[parameter_names[i]] = input
-            else:
-                raise ValueError(
-                    f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
-                    f" {parameter_names[i]}."
-                )
-    elif isinstance(input_values, Mapping):
-        if "inputs" in input_values:
-            warnings.warn(
-                "The `inputs` argument is deprecated and will be removed in a future version, use `input_values`"
-                " instead.",
-                FutureWarning,
-            )
-
-            output["input_values"] = input_values.pop("inputs")
-
-        if "decoder_cached_states" in input_values:
-            warnings.warn(
-                "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
-                " `past_key_values` instead.",
-                FutureWarning,
-            )
-            output["past_key_values"] = input_values.pop("decoder_cached_states")
-
-        for k, v in dict(input_values).items():
-            if isinstance(v, allowed_types) or v is None:
-                output[k] = v
-            elif k not in parameter_names and "args" not in parameter_names:
-                logger.warning(
-                    f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored."
-                )
-                continue
-            else:
-                raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
-    else:
-        if isinstance(input_values, tf.Tensor) or input_values is None:
-            output[parameter_names[0]] = input_values
-        else:
-            raise ValueError(
-                f"Data of type {type(input_values)} is not allowed only {allowed_types} is accepted for"
-                f" {parameter_names[0]}."
-            )
-
-    for name in parameter_names:
-        if name not in list(output.keys()) and name != "args":
-            output[name] = kwargs.pop(name, signature[name].default)
-
-    # When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
-    # So to respect the proper output we have to add this exception
-    if "args" in output:
-        if output["args"] is not None and type(output["args"]) == tf.Tensor:
-            tensor_name = output["args"].name.split(":")[0]
-            output[tensor_name] = output["args"]
-        else:
-            # `args` in this case is always the first parameter, then `input_values`
-            output["input_values"] = output["args"]
-
-        del output["args"]
-
-    if "kwargs" in output:
-        del output["kwargs"]
-
-    boolean_dict = {
-        k: v
-        for k, v in output.items()
-        if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
-    }
-
-    output.update(booleans_processing(config=config, **boolean_dict))
-
-    return output
-
-
 # Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2._sample_without_replacement
 def _sample_without_replacement(distribution, num_samples):
     """
@@ -1208,6 +1092,7 @@ class TFHubertMainLayer(tf.keras.layers.Layer):
 
         return hidden_states
 
+    @unpack_inputs
     def call(
         self,
         input_values: tf.Tensor,
@@ -1222,51 +1107,33 @@
         training: bool = False,
         **kwargs: Any,
     ):
-        inputs = input_values_processing(
-            func=self.call,
-            config=self.config,
-            input_values=input_values,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        hidden_states = self.feature_extractor(
-            tf.cast(inputs["input_values"], tf.float32), training=inputs["training"]
-        )
+        hidden_states = self.feature_extractor(tf.cast(input_values, tf.float32), training=training)
 
-        if inputs["attention_mask"] is not None:
+        if attention_mask is not None:
             # compute real output lengths according to convolution formula
-            output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1))
+            output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, -1))
 
             attention_mask = tf.sequence_mask(
                 output_lengths, maxlen=shape_list(hidden_states)[1], dtype=hidden_states.dtype
             )
 
-        hidden_states = self.feature_projection(hidden_states, training=inputs["training"])
+        hidden_states = self.feature_projection(hidden_states, training=training)
 
         mask_time_indices = kwargs.get("mask_time_indices", None)
-        if inputs["training"]:
+        if training:
             hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
 
         encoder_outputs = self.encoder(
             hidden_states,
             attention_mask=attention_mask,
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )
         hidden_states = encoder_outputs[0]
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             return (hidden_states,) + encoder_outputs[1:]
 
         return TFBaseModelOutput(
@@ -1428,6 +1295,7 @@ class TFHubertModel(TFHubertPreTrainedModel):
 
     @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
+    @unpack_inputs
     def call(
         self,
         input_values: tf.Tensor,
@@ -1469,9 +1337,11 @@
         >>> hidden_states = model(input_values).last_hidden_state
         ```"""
 
-        inputs = input_values_processing(
-            func=self.call,
-            config=self.config,
+        output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
+        output_attentions = output_attentions if output_attentions else self.config.output_attentions
+        return_dict = return_dict if return_dict else self.config.return_dict
+
+        outputs = self.hubert(
             input_values=input_values,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1484,27 +1354,6 @@
             training=training,
         )
 
-        inputs["output_hidden_states"] = (
-            inputs["output_hidden_states"] if inputs["output_hidden_states"] else self.config.output_hidden_states
-        )
-        inputs["output_attentions"] = (
-            inputs["output_attentions"] if inputs["output_attentions"] else self.config.output_attentions
-        )
-        inputs["return_dict"] = inputs["return_dict"] if inputs["return_dict"] else self.config.return_dict
-
-        outputs = self.hubert(
-            input_values=inputs["input_values"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )
-
         return outputs
 
     def serving_output(self, output):
@@ -1548,6 +1397,7 @@ class TFHubertForCTC(TFHubertPreTrainedModel):
 
     @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
+    @unpack_inputs
     def call(
         self,
         input_values: tf.Tensor,
@@ -1605,9 +1455,8 @@
         >>> loss = model(input_values, labels=labels).loss
         ```"""
 
-        inputs = input_values_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.hubert(
             input_values=input_values,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1619,21 +1468,8 @@
             return_dict=return_dict,
             training=training,
         )
-        outputs = self.hubert(
-            input_values=inputs["input_values"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )
         hidden_states = outputs[0]
-        hidden_states = self.dropout(hidden_states, training=inputs["training"])
+        hidden_states = self.dropout(hidden_states, training=training)
 
         logits = self.lm_head(hidden_states)
@@ -1642,9 +1478,7 @@
             raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
 
         attention_mask = (
-            inputs["attention_mask"]
-            if inputs["attention_mask"] is not None
-            else tf.ones_like(inputs["input_values"], dtype=tf.float32)
+            attention_mask if attention_mask is not None else tf.ones_like(input_values, dtype=tf.float32)
         )
 
         input_lengths = self.hubert._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, axis=-1))
@@ -1671,7 +1505,7 @@ class TFHubertForCTC(TFHubertPreTrainedModel):
         else:
             loss = None
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
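
For callers the refactor is behavior-preserving: unset booleans still fall back to the config, now inside the decorator instead of the deleted helper. An illustrative call, reusing the checkpoint from the docstring example above:

import tensorflow as tf
from transformers import TFHubertModel

model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
input_values = tf.random.normal((1, 16000))  # one second of dummy 16 kHz audio

# None (the default) falls back to config.output_hidden_states; an explicit
# True/False is forwarded to the main layer unchanged.
outputs = model(input_values, output_hidden_states=True)
print(len(outputs.hidden_states))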

src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py

@@ -14,9 +14,7 @@
 # limitations under the License.
 """ TensorFlow Wav2Vec2 model."""
-import inspect
 import warnings
-from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
 
 import numpy as np
@@ -27,7 +25,6 @@ from ...activations_tf import get_tf_activation
 from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
 from ...modeling_tf_utils import (
     TFPreTrainedModel,
-    booleans_processing,
     get_initializer,
     keras_serializable,
     unpack_inputs,
@@ -91,123 +88,6 @@ class TFWav2Vec2BaseModelOutput(ModelOutput):
     attentions: Optional[Tuple[tf.Tensor]] = None
 
 
-def input_values_processing(func, config, input_values, **kwargs):
-    """
-    Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
-    has to be named accordingly to the parameters name, i.e. `input_values = tf.keras.Input(shape=(128,),
-    dtype='float32', name="input_values")` otherwise the order of the tensors will not be guaranteed during the
-    training.
-
-    Args:
-        func (`callable`):
-            The callable function of the TensorFlow model.
-        config ([`PretrainedConfig`]):
-            The config of the running model.
-        **kwargs:
-            The inputs of the model.
-
-    Returns:
-        Two lists, one for the missing layers, and another one for the unexpected layers.
-    """
-    signature = dict(inspect.signature(func).parameters)
-    signature.pop("kwargs", None)
-    signature.pop("self", None)
-    parameter_names = list(signature.keys())
-    output = {}
-    allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
-
-    for k, v in kwargs.items():
-        if isinstance(v, allowed_types) or v is None:
-            output[k] = v
-        else:
-            raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
-
-    if isinstance(input_values, (tuple, list)):
-        for i, input in enumerate(input_values):
-            # EagerTensors don't allow to use the .name property so we check for a real Tensor
-            if type(input) == tf.Tensor:
-                # Tensor names have always the pattern `name:id` then we check only the
-                # `name` part
-                tensor_name = input.name.split(":")[0]
-
-                if tensor_name in parameter_names:
-                    output[tensor_name] = input
-                else:
-                    output[parameter_names[i]] = input
-            elif isinstance(input, allowed_types) or input is None:
-                output[parameter_names[i]] = input
-            else:
-                raise ValueError(
-                    f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
-                    f" {parameter_names[i]}."
-                )
-    elif isinstance(input_values, Mapping):
-        if "inputs" in input_values:
-            warnings.warn(
-                "The `inputs` argument is deprecated and will be removed in a future version, use `input_values`"
-                " instead.",
-                FutureWarning,
-            )
-
-            output["input_values"] = input_values.pop("inputs")
-
-        if "decoder_cached_states" in input_values:
-            warnings.warn(
-                "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
-                " `past_key_values` instead.",
-                FutureWarning,
-            )
-            output["past_key_values"] = input_values.pop("decoder_cached_states")
-
-        for k, v in dict(input_values).items():
-            if isinstance(v, allowed_types) or v is None:
-                output[k] = v
-            elif k not in parameter_names and "args" not in parameter_names:
-                logger.warning(
-                    f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored."
-                )
-                continue
-            else:
-                raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
-    else:
-        if isinstance(input_values, tf.Tensor) or input_values is None:
-            output[parameter_names[0]] = input_values
-        else:
-            raise ValueError(
-                f"Data of type {type(input_values)} is not allowed only {allowed_types} is accepted for"
-                f" {parameter_names[0]}."
-            )
-
-    for name in parameter_names:
-        if name not in list(output.keys()) and name != "args":
-            output[name] = kwargs.pop(name, signature[name].default)
-
-    # When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
-    # So to respect the proper output we have to add this exception
-    if "args" in output:
-        if output["args"] is not None and type(output["args"]) == tf.Tensor:
-            tensor_name = output["args"].name.split(":")[0]
-            output[tensor_name] = output["args"]
-        else:
-            # `args` in this case is always the first parameter, then `input_values`
-            output["input_values"] = output["args"]
-
-        del output["args"]
-
-    if "kwargs" in output:
-        del output["kwargs"]
-
-    boolean_dict = {
-        k: v
-        for k, v in output.items()
-        if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
-    }
-
-    output.update(booleans_processing(config=config, **boolean_dict))
-
-    return output
-
-
 def _sample_without_replacement(distribution, num_samples):
     """
     Categorical sampling without replacement is currently not implemented. The gumbel-max trick will do for now - see
@@ -1238,6 +1118,7 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
 
         return hidden_states
 
+    @unpack_inputs
     def call(
         self,
         input_values: tf.Tensor,
@@ -1252,52 +1133,34 @@
         training: bool = False,
         **kwargs: Any,
     ):
-        inputs = input_values_processing(
-            func=self.call,
-            config=self.config,
-            input_values=input_values,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        extract_features = self.feature_extractor(
-            tf.cast(inputs["input_values"], tf.float32), training=inputs["training"]
-        )
+        extract_features = self.feature_extractor(tf.cast(input_values, tf.float32), training=training)
         # extract_features = tf.transpose(extract_features, perm=(0, 2, 1))
 
-        if inputs["attention_mask"] is not None:
+        if attention_mask is not None:
             # compute real output lengths according to convolution formula
-            output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1))
+            output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, -1))
 
             attention_mask = tf.sequence_mask(
                 output_lengths, maxlen=shape_list(extract_features)[1], dtype=extract_features.dtype
             )
 
-        hidden_states, extract_features = self.feature_projection(extract_features, training=inputs["training"])
+        hidden_states, extract_features = self.feature_projection(extract_features, training=training)
 
         mask_time_indices = kwargs.get("mask_time_indices", None)
-        if inputs["training"]:
+        if training:
             hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
 
         encoder_outputs = self.encoder(
             hidden_states,
             attention_mask=attention_mask,
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )
         hidden_states = encoder_outputs[0]
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             return (hidden_states, extract_features) + encoder_outputs[1:]
 
         return TFWav2Vec2BaseModelOutput(
@@ -1460,6 +1323,7 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
 
     @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
+    @unpack_inputs
     def call(
         self,
         input_values: tf.Tensor,
@@ -1501,9 +1365,11 @@
         >>> hidden_states = model(input_values).last_hidden_state
         ```"""
 
-        inputs = input_values_processing(
-            func=self.call,
-            config=self.config,
+        output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
+        output_attentions = output_attentions if output_attentions else self.config.output_attentions
+        return_dict = return_dict if return_dict else self.config.return_dict
+
+        outputs = self.wav2vec2(
             input_values=input_values,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1516,27 +1382,6 @@
             training=training,
         )
 
-        inputs["output_hidden_states"] = (
-            inputs["output_hidden_states"] if inputs["output_hidden_states"] else self.config.output_hidden_states
-        )
-        inputs["output_attentions"] = (
-            inputs["output_attentions"] if inputs["output_attentions"] else self.config.output_attentions
-        )
-        inputs["return_dict"] = inputs["return_dict"] if inputs["return_dict"] else self.config.return_dict
-
-        outputs = self.wav2vec2(
-            input_values=inputs["input_values"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )
-
         return outputs
 
     def serving_output(self, output):
@@ -1642,9 +1487,8 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
         >>> loss = model(input_values, labels=labels).loss
         ```"""
 
-        inputs = input_values_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.wav2vec2(
             input_values=input_values,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1656,21 +1500,8 @@
             return_dict=return_dict,
             training=training,
         )
-        outputs = self.wav2vec2(
-            input_values=inputs["input_values"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )
         hidden_states = outputs[0]
-        hidden_states = self.dropout(hidden_states, training=inputs["training"])
+        hidden_states = self.dropout(hidden_states, training=training)
 
         logits = self.lm_head(hidden_states)
@@ -1679,9 +1510,7 @@
             raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
 
         attention_mask = (
-            inputs["attention_mask"]
-            if inputs["attention_mask"] is not None
-            else tf.ones_like(inputs["input_values"], dtype=tf.float32)
+            attention_mask if attention_mask is not None else tf.ones_like(input_values, dtype=tf.float32)
         )
 
         input_lengths = self.wav2vec2._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, axis=-1))
@@ -1708,7 +1537,7 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
         else:
             loss = None
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
             return ((loss,) + output) if loss is not None else output
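
Both files lean on `_get_feat_extract_output_lengths` to shrink the sample-level attention mask to feature-frame length. It folds the no-padding 1D-convolution length formula over the feature encoder's layers; a self-contained sketch, with kernel and stride values assumed from the Wav2Vec2-base config for illustration:

def feat_extract_output_lengths(input_length, conv_kernel, conv_stride):
    # No-padding 1D convolution: out = (in - kernel) // stride + 1, per layer.
    for kernel_size, stride in zip(conv_kernel, conv_stride):
        input_length = (input_length - kernel_size) // stride + 1
    return input_length


# Wav2Vec2-base feature encoder (config.conv_kernel / config.conv_stride):
print(feat_extract_output_lengths(16000, (10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)))  # -> 49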

tests/models/hubert/test_modeling_tf_hubert.py

@@ -304,18 +304,15 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
 
-    # Hubert has no inputs_embeds
+    @unittest.skip(reason="Hubert has no input embeddings")
     def test_inputs_embeds(self):
         pass
 
-    # Hubert cannot resize token embeddings
-    # since it has no tokens embeddings
+    @unittest.skip(reason="Hubert has no tokens embeddings")
     def test_resize_tokens_embeddings(self):
         pass
 
-    # Hubert has no inputs_embeds
-    # and thus the `get_input_embeddings` fn
-    # is not implemented
+    @unittest.skip(reason="Hubert has no input embeddings")
     def test_model_common_attributes(self):
         pass
@@ -324,10 +321,6 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase):
         model = TFHubertModel.from_pretrained("facebook/hubert-base-ls960")
         self.assertIsNotNone(model)
 
-    @unittest.skip("Loss shapes for CTC don't match the base test.")
-    def test_loss_computation(self):
-        pass
-
 
 @require_tf
 class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
@@ -426,29 +419,36 @@ class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
 
-    # Hubert has no inputs_embeds
+    @unittest.skip(reason="Hubert has no input embeddings")
     def test_inputs_embeds(self):
         pass
 
-    # Hubert cannot resize token embeddings
-    # since it has no tokens embeddings
+    @unittest.skip(reason="Hubert has no tokens embeddings")
     def test_resize_tokens_embeddings(self):
         pass
 
-    # Hubert has no inputs_embeds
-    # and thus the `get_input_embeddings` fn
-    # is not implemented
+    @unittest.skip(reason="Hubert has no input embeddings or get_input_embeddings method")
     def test_model_common_attributes(self):
         pass
 
+    # We override here as passing a full batch of 13 samples results in OOM errors for CTC
+    def test_dataset_conversion(self):
+        default_batch_size = self.model_tester.batch_size
+        self.model_tester.batch_size = 2
+        super().test_dataset_conversion()
+        self.model_tester.batch_size = default_batch_size
+
     @slow
     def test_model_from_pretrained(self):
         model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
         self.assertIsNotNone(model)
 
-    @unittest.skip("Loss shapes for CTC don't match the base test.")
-    def test_loss_computation(self):
-        pass
+    # We override here as passing a full batch of 13 samples results in OOM errors for CTC
+    def test_keras_fit(self):
+        default_batch_size = self.model_tester.batch_size
+        self.model_tester.batch_size = 2
+        super().test_keras_fit()
+        self.model_tester.batch_size = default_batch_size
 
 
 @require_tf

tests/models/wav2vec2/test_modeling_tf_wav2vec2.py

@@ -369,18 +369,15 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.check_training(*config_and_inputs)
 
-    # Wav2Vec2 has no inputs_embeds
+    @unittest.skip(reason="Wav2Vec2 has no input embeddings")
     def test_inputs_embeds(self):
         pass
 
-    # Wav2Vec2 cannot resize token embeddings
-    # since it has no tokens embeddings
+    @unittest.skip(reason="Wav2Vec2 has no tokens embeddings")
     def test_resize_tokens_embeddings(self):
         pass
 
-    # Wav2Vec2 has no inputs_embeds
-    # and thus the `get_input_embeddings` fn
-    # is not implemented
+    @unittest.skip(reason="Wav2Vec2 has no input embeddings")
     def test_model_common_attributes(self):
         pass
@@ -389,13 +386,19 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
         model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
         self.assertIsNotNone(model)
 
-    @unittest.skip(reason="Dataset conversion goes OOM and crashes with the default options!")
+    # We override here as passing a full batch of 13 samples results in OOM errors for CTC
     def test_dataset_conversion(self):
-        pass
+        default_batch_size = self.model_tester.batch_size
+        self.model_tester.batch_size = 2
+        super().test_dataset_conversion()
+        self.model_tester.batch_size = default_batch_size
 
-    @unittest.skip(reason="Training goes OOM and crashes with the default options!")
+    # We override here as passing a full batch of 13 samples results in OOM errors for CTC
     def test_keras_fit(self):
-        pass
+        default_batch_size = self.model_tester.batch_size
+        self.model_tester.batch_size = 2
+        super().test_keras_fit()
+        self.model_tester.batch_size = default_batch_size
 
 
 @require_tf
@@ -497,18 +500,15 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.check_training(*config_and_inputs)
 
-    # Wav2Vec2 has no inputs_embeds
+    @unittest.skip(reason="Wav2Vec2 has no input embeddings")
     def test_inputs_embeds(self):
         pass
 
-    # Wav2Vec2 cannot resize token embeddings
-    # since it has no tokens embeddings
+    @unittest.skip(reason="Wav2Vec2 has no tokens embeddings")
     def test_resize_tokens_embeddings(self):
         pass
 
-    # Wav2Vec2 has no inputs_embeds
-    # and thus the `get_input_embeddings` fn
-    # is not implemented
+    @unittest.skip(reason="Wav2Vec2 has no input embeddings")
     def test_model_common_attributes(self):
         pass
@@ -517,13 +517,19 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
         model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
         self.assertIsNotNone(model)
 
-    @unittest.skip(reason="Dataset conversion goes OOM and crashes with the default options!")
+    # We override here as passing a full batch of 13 samples results in OOM errors for CTC
     def test_dataset_conversion(self):
-        pass
+        default_batch_size = self.model_tester.batch_size
+        self.model_tester.batch_size = 2
+        super().test_dataset_conversion()
+        self.model_tester.batch_size = default_batch_size
 
-    @unittest.skip(reason="Training goes OOM and crashes with the default options!")
+    # We override here as passing a full batch of 13 samples results in OOM errors for CTC
     def test_keras_fit(self):
-        pass
+        default_batch_size = self.model_tester.batch_size
+        self.model_tester.batch_size = 2
+        super().test_keras_fit()
+        self.model_tester.batch_size = default_batch_size
@require_tf @require_tf
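
The overrides above repeat a save/patch/restore dance around the parent test. A hypothetical exception-safe variant of the same pattern, not part of this commit, should it spread to more tests:

from contextlib import contextmanager


@contextmanager
def small_batch(model_tester, batch_size=2):
    # Restore the original batch size even if the wrapped test raises.
    original = model_tester.batch_size
    model_tester.batch_size = batch_size
    try:
        yield
    finally:
        model_tester.batch_size = original


# Usage inside an override:
#     def test_keras_fit(self):
#         with small_batch(self.model_tester):
#             super().test_keras_fit()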