mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-24 14:58:56 +06:00
Replace input_values_processing with unpack_inputs (#21502)
* Replace input_values_prrocessing with unpack_inputs * Skip test failing with OOM * Update tests
This commit is contained in:
parent
557125637d
commit
cb56590111
@ -13,9 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" TensorFlow Hubert model."""
|
||||
import inspect
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@ -23,10 +21,14 @@ import tensorflow as tf
|
||||
|
||||
from ...activations_tf import get_tf_activation
|
||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
|
||||
from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
|
||||
from ...modeling_tf_utils import (
|
||||
TFPreTrainedModel,
|
||||
get_initializer,
|
||||
keras_serializable,
|
||||
unpack_inputs,
|
||||
)
|
||||
from ...tf_utils import shape_list, stable_softmax
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
@ -47,124 +49,6 @@ TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
LARGE_NEGATIVE = -1e8
|
||||
|
||||
|
||||
# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.input_values_processing
|
||||
def input_values_processing(func, config, input_values, **kwargs):
|
||||
"""
|
||||
Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
|
||||
has to be named accordingly to the parameters name, i.e. `input_values = tf.keras.Input(shape=(128,),
|
||||
dtype='float32', name="input_values")` otherwise the order of the tensors will not be guaranteed during the
|
||||
training.
|
||||
|
||||
Args:
|
||||
func (`callable`):
|
||||
The callable function of the TensorFlow model.
|
||||
config ([`PretrainedConfig`]):
|
||||
The config of the running model.
|
||||
**kwargs:
|
||||
The inputs of the model.
|
||||
|
||||
Returns:
|
||||
Two lists, one for the missing layers, and another one for the unexpected layers.
|
||||
"""
|
||||
signature = dict(inspect.signature(func).parameters)
|
||||
signature.pop("kwargs", None)
|
||||
signature.pop("self", None)
|
||||
parameter_names = list(signature.keys())
|
||||
output = {}
|
||||
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
|
||||
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, allowed_types) or v is None:
|
||||
output[k] = v
|
||||
else:
|
||||
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
||||
|
||||
if isinstance(input_values, (tuple, list)):
|
||||
for i, input in enumerate(input_values):
|
||||
# EagerTensors don't allow to use the .name property so we check for a real Tensor
|
||||
if type(input) == tf.Tensor:
|
||||
# Tensor names have always the pattern `name:id` then we check only the
|
||||
# `name` part
|
||||
tensor_name = input.name.split(":")[0]
|
||||
|
||||
if tensor_name in parameter_names:
|
||||
output[tensor_name] = input
|
||||
else:
|
||||
output[parameter_names[i]] = input
|
||||
elif isinstance(input, allowed_types) or input is None:
|
||||
output[parameter_names[i]] = input
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
|
||||
f" {parameter_names[i]}."
|
||||
)
|
||||
elif isinstance(input_values, Mapping):
|
||||
if "inputs" in input_values:
|
||||
warnings.warn(
|
||||
"The `inputs` argument is deprecated and will be removed in a future version, use `input_values`"
|
||||
" instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
output["input_values"] = input_values.pop("inputs")
|
||||
|
||||
if "decoder_cached_states" in input_values:
|
||||
warnings.warn(
|
||||
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
|
||||
" `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
output["past_key_values"] = input_values.pop("decoder_cached_states")
|
||||
|
||||
for k, v in dict(input_values).items():
|
||||
if isinstance(v, allowed_types) or v is None:
|
||||
output[k] = v
|
||||
elif k not in parameter_names and "args" not in parameter_names:
|
||||
logger.warning(
|
||||
f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored."
|
||||
)
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
||||
else:
|
||||
if isinstance(input_values, tf.Tensor) or input_values is None:
|
||||
output[parameter_names[0]] = input_values
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Data of type {type(input_values)} is not allowed only {allowed_types} is accepted for"
|
||||
f" {parameter_names[0]}."
|
||||
)
|
||||
|
||||
for name in parameter_names:
|
||||
if name not in list(output.keys()) and name != "args":
|
||||
output[name] = kwargs.pop(name, signature[name].default)
|
||||
|
||||
# When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
|
||||
# So to respect the proper output we have to add this exception
|
||||
if "args" in output:
|
||||
if output["args"] is not None and type(output["args"]) == tf.Tensor:
|
||||
tensor_name = output["args"].name.split(":")[0]
|
||||
output[tensor_name] = output["args"]
|
||||
else:
|
||||
# `args` in this case is always the first parameter, then `input_values`
|
||||
output["input_values"] = output["args"]
|
||||
|
||||
del output["args"]
|
||||
|
||||
if "kwargs" in output:
|
||||
del output["kwargs"]
|
||||
|
||||
boolean_dict = {
|
||||
k: v
|
||||
for k, v in output.items()
|
||||
if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
|
||||
}
|
||||
|
||||
output.update(booleans_processing(config=config, **boolean_dict))
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2._sample_without_replacement
|
||||
def _sample_without_replacement(distribution, num_samples):
|
||||
"""
|
||||
@ -1208,6 +1092,7 @@ class TFHubertMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
return hidden_states
|
||||
|
||||
@unpack_inputs
|
||||
def call(
|
||||
self,
|
||||
input_values: tf.Tensor,
|
||||
@ -1222,51 +1107,33 @@ class TFHubertMainLayer(tf.keras.layers.Layer):
|
||||
training: bool = False,
|
||||
**kwargs: Any,
|
||||
):
|
||||
inputs = input_values_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_values=input_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
hidden_states = self.feature_extractor(tf.cast(input_values, tf.float32), training=training)
|
||||
|
||||
hidden_states = self.feature_extractor(
|
||||
tf.cast(inputs["input_values"], tf.float32), training=inputs["training"]
|
||||
)
|
||||
|
||||
if inputs["attention_mask"] is not None:
|
||||
if attention_mask is not None:
|
||||
# compute real output lengths according to convolution formula
|
||||
output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1))
|
||||
output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, -1))
|
||||
|
||||
attention_mask = tf.sequence_mask(
|
||||
output_lengths, maxlen=shape_list(hidden_states)[1], dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
hidden_states = self.feature_projection(hidden_states, training=inputs["training"])
|
||||
hidden_states = self.feature_projection(hidden_states, training=training)
|
||||
|
||||
mask_time_indices = kwargs.get("mask_time_indices", None)
|
||||
if inputs["training"]:
|
||||
if training:
|
||||
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
hidden_states = encoder_outputs[0]
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
return (hidden_states,) + encoder_outputs[1:]
|
||||
|
||||
return TFBaseModelOutput(
|
||||
@ -1428,6 +1295,7 @@ class TFHubertModel(TFHubertPreTrainedModel):
|
||||
|
||||
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@unpack_inputs
|
||||
def call(
|
||||
self,
|
||||
input_values: tf.Tensor,
|
||||
@ -1469,9 +1337,11 @@ class TFHubertModel(TFHubertPreTrainedModel):
|
||||
>>> hidden_states = model(input_values).last_hidden_state
|
||||
```"""
|
||||
|
||||
inputs = input_values_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
|
||||
output_attentions = output_attentions if output_attentions else self.config.output_attentions
|
||||
return_dict = return_dict if return_dict else self.config.return_dict
|
||||
|
||||
outputs = self.hubert(
|
||||
input_values=input_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@ -1484,27 +1354,6 @@ class TFHubertModel(TFHubertPreTrainedModel):
|
||||
training=training,
|
||||
)
|
||||
|
||||
inputs["output_hidden_states"] = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] else self.config.output_hidden_states
|
||||
)
|
||||
inputs["output_attentions"] = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] else self.config.output_attentions
|
||||
)
|
||||
inputs["return_dict"] = inputs["return_dict"] if inputs["return_dict"] else self.config.return_dict
|
||||
|
||||
outputs = self.hubert(
|
||||
input_values=inputs["input_values"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
def serving_output(self, output):
|
||||
@ -1548,6 +1397,7 @@ class TFHubertForCTC(TFHubertPreTrainedModel):
|
||||
|
||||
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@unpack_inputs
|
||||
def call(
|
||||
self,
|
||||
input_values: tf.Tensor,
|
||||
@ -1605,9 +1455,8 @@ class TFHubertForCTC(TFHubertPreTrainedModel):
|
||||
|
||||
>>> loss = model(input_values, labels=labels).loss
|
||||
```"""
|
||||
inputs = input_values_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
|
||||
outputs = self.hubert(
|
||||
input_values=input_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@ -1619,21 +1468,8 @@ class TFHubertForCTC(TFHubertPreTrainedModel):
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
|
||||
outputs = self.hubert(
|
||||
input_values=inputs["input_values"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
@ -1642,9 +1478,7 @@ class TFHubertForCTC(TFHubertPreTrainedModel):
|
||||
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
||||
|
||||
attention_mask = (
|
||||
inputs["attention_mask"]
|
||||
if inputs["attention_mask"] is not None
|
||||
else tf.ones_like(inputs["input_values"], dtype=tf.float32)
|
||||
attention_mask if attention_mask is not None else tf.ones_like(input_values, dtype=tf.float32)
|
||||
)
|
||||
input_lengths = self.hubert._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, axis=-1))
|
||||
|
||||
@ -1671,7 +1505,7 @@ class TFHubertForCTC(TFHubertPreTrainedModel):
|
||||
else:
|
||||
loss = None
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
|
@ -14,9 +14,7 @@
|
||||
# limitations under the License.
|
||||
""" TensorFlow Wav2Vec2 model."""
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
@ -27,7 +25,6 @@ from ...activations_tf import get_tf_activation
|
||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
|
||||
from ...modeling_tf_utils import (
|
||||
TFPreTrainedModel,
|
||||
booleans_processing,
|
||||
get_initializer,
|
||||
keras_serializable,
|
||||
unpack_inputs,
|
||||
@ -91,123 +88,6 @@ class TFWav2Vec2BaseModelOutput(ModelOutput):
|
||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
|
||||
|
||||
def input_values_processing(func, config, input_values, **kwargs):
|
||||
"""
|
||||
Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
|
||||
has to be named accordingly to the parameters name, i.e. `input_values = tf.keras.Input(shape=(128,),
|
||||
dtype='float32', name="input_values")` otherwise the order of the tensors will not be guaranteed during the
|
||||
training.
|
||||
|
||||
Args:
|
||||
func (`callable`):
|
||||
The callable function of the TensorFlow model.
|
||||
config ([`PretrainedConfig`]):
|
||||
The config of the running model.
|
||||
**kwargs:
|
||||
The inputs of the model.
|
||||
|
||||
Returns:
|
||||
Two lists, one for the missing layers, and another one for the unexpected layers.
|
||||
"""
|
||||
signature = dict(inspect.signature(func).parameters)
|
||||
signature.pop("kwargs", None)
|
||||
signature.pop("self", None)
|
||||
parameter_names = list(signature.keys())
|
||||
output = {}
|
||||
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
|
||||
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, allowed_types) or v is None:
|
||||
output[k] = v
|
||||
else:
|
||||
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
||||
|
||||
if isinstance(input_values, (tuple, list)):
|
||||
for i, input in enumerate(input_values):
|
||||
# EagerTensors don't allow to use the .name property so we check for a real Tensor
|
||||
if type(input) == tf.Tensor:
|
||||
# Tensor names have always the pattern `name:id` then we check only the
|
||||
# `name` part
|
||||
tensor_name = input.name.split(":")[0]
|
||||
|
||||
if tensor_name in parameter_names:
|
||||
output[tensor_name] = input
|
||||
else:
|
||||
output[parameter_names[i]] = input
|
||||
elif isinstance(input, allowed_types) or input is None:
|
||||
output[parameter_names[i]] = input
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
|
||||
f" {parameter_names[i]}."
|
||||
)
|
||||
elif isinstance(input_values, Mapping):
|
||||
if "inputs" in input_values:
|
||||
warnings.warn(
|
||||
"The `inputs` argument is deprecated and will be removed in a future version, use `input_values`"
|
||||
" instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
output["input_values"] = input_values.pop("inputs")
|
||||
|
||||
if "decoder_cached_states" in input_values:
|
||||
warnings.warn(
|
||||
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
|
||||
" `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
output["past_key_values"] = input_values.pop("decoder_cached_states")
|
||||
|
||||
for k, v in dict(input_values).items():
|
||||
if isinstance(v, allowed_types) or v is None:
|
||||
output[k] = v
|
||||
elif k not in parameter_names and "args" not in parameter_names:
|
||||
logger.warning(
|
||||
f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored."
|
||||
)
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
||||
else:
|
||||
if isinstance(input_values, tf.Tensor) or input_values is None:
|
||||
output[parameter_names[0]] = input_values
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Data of type {type(input_values)} is not allowed only {allowed_types} is accepted for"
|
||||
f" {parameter_names[0]}."
|
||||
)
|
||||
|
||||
for name in parameter_names:
|
||||
if name not in list(output.keys()) and name != "args":
|
||||
output[name] = kwargs.pop(name, signature[name].default)
|
||||
|
||||
# When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
|
||||
# So to respect the proper output we have to add this exception
|
||||
if "args" in output:
|
||||
if output["args"] is not None and type(output["args"]) == tf.Tensor:
|
||||
tensor_name = output["args"].name.split(":")[0]
|
||||
output[tensor_name] = output["args"]
|
||||
else:
|
||||
# `args` in this case is always the first parameter, then `input_values`
|
||||
output["input_values"] = output["args"]
|
||||
|
||||
del output["args"]
|
||||
|
||||
if "kwargs" in output:
|
||||
del output["kwargs"]
|
||||
|
||||
boolean_dict = {
|
||||
k: v
|
||||
for k, v in output.items()
|
||||
if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
|
||||
}
|
||||
|
||||
output.update(booleans_processing(config=config, **boolean_dict))
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _sample_without_replacement(distribution, num_samples):
|
||||
"""
|
||||
Categorical sampling without replacement is currently not implemented. The gumbel-max trick will do for now - see
|
||||
@ -1238,6 +1118,7 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
|
||||
|
||||
return hidden_states
|
||||
|
||||
@unpack_inputs
|
||||
def call(
|
||||
self,
|
||||
input_values: tf.Tensor,
|
||||
@ -1252,52 +1133,34 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
|
||||
training: bool = False,
|
||||
**kwargs: Any,
|
||||
):
|
||||
inputs = input_values_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_values=input_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
extract_features = self.feature_extractor(
|
||||
tf.cast(inputs["input_values"], tf.float32), training=inputs["training"]
|
||||
)
|
||||
extract_features = self.feature_extractor(tf.cast(input_values, tf.float32), training=training)
|
||||
# extract_features = tf.transpose(extract_features, perm=(0, 2, 1))
|
||||
|
||||
if inputs["attention_mask"] is not None:
|
||||
if attention_mask is not None:
|
||||
# compute real output lengths according to convolution formula
|
||||
output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1))
|
||||
output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, -1))
|
||||
|
||||
attention_mask = tf.sequence_mask(
|
||||
output_lengths, maxlen=shape_list(extract_features)[1], dtype=extract_features.dtype
|
||||
)
|
||||
|
||||
hidden_states, extract_features = self.feature_projection(extract_features, training=inputs["training"])
|
||||
hidden_states, extract_features = self.feature_projection(extract_features, training=training)
|
||||
|
||||
mask_time_indices = kwargs.get("mask_time_indices", None)
|
||||
if inputs["training"]:
|
||||
if training:
|
||||
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
hidden_states = encoder_outputs[0]
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
return (hidden_states, extract_features) + encoder_outputs[1:]
|
||||
|
||||
return TFWav2Vec2BaseModelOutput(
|
||||
@ -1460,6 +1323,7 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
|
||||
|
||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@unpack_inputs
|
||||
def call(
|
||||
self,
|
||||
input_values: tf.Tensor,
|
||||
@ -1501,9 +1365,11 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
|
||||
>>> hidden_states = model(input_values).last_hidden_state
|
||||
```"""
|
||||
|
||||
inputs = input_values_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
|
||||
output_attentions = output_attentions if output_attentions else self.config.output_attentions
|
||||
return_dict = return_dict if return_dict else self.config.return_dict
|
||||
|
||||
outputs = self.wav2vec2(
|
||||
input_values=input_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@ -1516,27 +1382,6 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
|
||||
training=training,
|
||||
)
|
||||
|
||||
inputs["output_hidden_states"] = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] else self.config.output_hidden_states
|
||||
)
|
||||
inputs["output_attentions"] = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] else self.config.output_attentions
|
||||
)
|
||||
inputs["return_dict"] = inputs["return_dict"] if inputs["return_dict"] else self.config.return_dict
|
||||
|
||||
outputs = self.wav2vec2(
|
||||
input_values=inputs["input_values"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
def serving_output(self, output):
|
||||
@ -1642,9 +1487,8 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
|
||||
|
||||
>>> loss = model(input_values, labels=labels).loss
|
||||
```"""
|
||||
inputs = input_values_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
|
||||
outputs = self.wav2vec2(
|
||||
input_values=input_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@ -1656,21 +1500,8 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
|
||||
outputs = self.wav2vec2(
|
||||
input_values=inputs["input_values"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
@ -1679,9 +1510,7 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
|
||||
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
||||
|
||||
attention_mask = (
|
||||
inputs["attention_mask"]
|
||||
if inputs["attention_mask"] is not None
|
||||
else tf.ones_like(inputs["input_values"], dtype=tf.float32)
|
||||
attention_mask if attention_mask is not None else tf.ones_like(input_values, dtype=tf.float32)
|
||||
)
|
||||
input_lengths = self.wav2vec2._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, axis=-1))
|
||||
|
||||
@ -1708,7 +1537,7 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
|
||||
else:
|
||||
loss = None
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
|
@ -304,18 +304,15 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||
|
||||
# Hubert has no inputs_embeds
|
||||
@unittest.skip(reason="Hubert has no input embeddings")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
# Hubert cannot resize token embeddings
|
||||
# since it has no tokens embeddings
|
||||
@unittest.skip(reason="Hubert has no tokens embeddings")
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
|
||||
# Hubert has no inputs_embeds
|
||||
# and thus the `get_input_embeddings` fn
|
||||
# is not implemented
|
||||
@unittest.skip(reason="Hubert has no input embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@ -324,10 +321,6 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
model = TFHubertModel.from_pretrained("facebook/hubert-base-ls960")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skip("Loss shapes for CTC don't match the base test.")
|
||||
def test_loss_computation(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
@ -426,29 +419,36 @@ class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||
|
||||
# Hubert has no inputs_embeds
|
||||
@unittest.skip(reason="Hubert has no input embeddings")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
# Hubert cannot resize token embeddings
|
||||
# since it has no tokens embeddings
|
||||
@unittest.skip(reason="Hubert has no tokens embeddings")
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
|
||||
# Hubert has no inputs_embeds
|
||||
# and thus the `get_input_embeddings` fn
|
||||
# is not implemented
|
||||
@unittest.skip(reason="Hubert has no input embeddings or get_input_embeddings method")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||
def test_dataset_conversion(self):
|
||||
default_batch_size = self.model_tester.batch_size
|
||||
self.model_tester.batch_size = 2
|
||||
super().test_dataset_conversion()
|
||||
self.model_tester.batch_size = default_batch_size
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skip("Loss shapes for CTC don't match the base test.")
|
||||
def test_loss_computation(self):
|
||||
pass
|
||||
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||
def test_keras_fit(self):
|
||||
default_batch_size = self.model_tester.batch_size
|
||||
self.model_tester.batch_size = 2
|
||||
super().test_keras_fit()
|
||||
self.model_tester.batch_size = default_batch_size
|
||||
|
||||
|
||||
@require_tf
|
||||
|
@ -369,18 +369,15 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_training(*config_and_inputs)
|
||||
|
||||
# Wav2Vec2 has no inputs_embeds
|
||||
@unittest.skip(reason="Wav2Vec2 has no input embeddings")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
# Wav2Vec2 cannot resize token embeddings
|
||||
# since it has no tokens embeddings
|
||||
@unittest.skip(reason="Wav2Vec2 has no tokens embeddings")
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
|
||||
# Wav2Vec2 has no inputs_embeds
|
||||
# and thus the `get_input_embeddings` fn
|
||||
# is not implemented
|
||||
@unittest.skip(reason="Wav2Vec2 has no input embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@ -389,13 +386,19 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skip(reason="Dataset conversion goes OOM and crashes with the default options!")
|
||||
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||
def test_dataset_conversion(self):
|
||||
pass
|
||||
default_batch_size = self.model_tester.batch_size
|
||||
self.model_tester.batch_size = 2
|
||||
super().test_dataset_conversion()
|
||||
self.model_tester.batch_size = default_batch_size
|
||||
|
||||
@unittest.skip(reason="Training goes OOM and crashes with the default options!")
|
||||
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||
def test_keras_fit(self):
|
||||
pass
|
||||
default_batch_size = self.model_tester.batch_size
|
||||
self.model_tester.batch_size = 2
|
||||
super().test_dataset_conversion()
|
||||
self.model_tester.batch_size = default_batch_size
|
||||
|
||||
|
||||
@require_tf
|
||||
@ -497,18 +500,15 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_training(*config_and_inputs)
|
||||
|
||||
# Wav2Vec2 has no inputs_embeds
|
||||
@unittest.skip(reason="Wav2Vec2 has no input embeddings")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
# Wav2Vec2 cannot resize token embeddings
|
||||
# since it has no tokens embeddings
|
||||
@unittest.skip(reason="Wav2Vec2 has no tokens embeddings")
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
|
||||
# Wav2Vec2 has no inputs_embeds
|
||||
# and thus the `get_input_embeddings` fn
|
||||
# is not implemented
|
||||
@unittest.skip(reason="Wav2Vec2 has no input embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@ -517,13 +517,19 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skip(reason="Dataset conversion goes OOM and crashes with the default options!")
|
||||
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||
def test_dataset_conversion(self):
|
||||
pass
|
||||
default_batch_size = self.model_tester.batch_size
|
||||
self.model_tester.batch_size = 2
|
||||
super().test_dataset_conversion()
|
||||
self.model_tester.batch_size = default_batch_size
|
||||
|
||||
@unittest.skip(reason="Training goes OOM and crashes with the default options!")
|
||||
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||
def test_keras_fit(self):
|
||||
pass
|
||||
default_batch_size = self.model_tester.batch_size
|
||||
self.model_tester.batch_size = 2
|
||||
super().test_dataset_conversion()
|
||||
self.model_tester.batch_size = default_batch_size
|
||||
|
||||
|
||||
@require_tf
|
||||
|
Loading…
Reference in New Issue
Block a user