mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-25 23:38:59 +06:00
Replace input_values_processing with unpack_inputs (#21502)
* Replace input_values_prrocessing with unpack_inputs * Skip test failing with OOM * Update tests
This commit is contained in:
parent
557125637d
commit
cb56590111
@ -13,9 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" TensorFlow Hubert model."""
|
""" TensorFlow Hubert model."""
|
||||||
import inspect
|
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Mapping
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -23,10 +21,14 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from ...activations_tf import get_tf_activation
|
from ...activations_tf import get_tf_activation
|
||||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
|
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
|
||||||
from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
|
from ...modeling_tf_utils import (
|
||||||
|
TFPreTrainedModel,
|
||||||
|
get_initializer,
|
||||||
|
keras_serializable,
|
||||||
|
unpack_inputs,
|
||||||
|
)
|
||||||
from ...tf_utils import shape_list, stable_softmax
|
from ...tf_utils import shape_list, stable_softmax
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
ModelOutput,
|
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
logging,
|
logging,
|
||||||
@ -47,124 +49,6 @@ TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|||||||
LARGE_NEGATIVE = -1e8
|
LARGE_NEGATIVE = -1e8
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.input_values_processing
|
|
||||||
def input_values_processing(func, config, input_values, **kwargs):
|
|
||||||
"""
|
|
||||||
Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
|
|
||||||
has to be named accordingly to the parameters name, i.e. `input_values = tf.keras.Input(shape=(128,),
|
|
||||||
dtype='float32', name="input_values")` otherwise the order of the tensors will not be guaranteed during the
|
|
||||||
training.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
func (`callable`):
|
|
||||||
The callable function of the TensorFlow model.
|
|
||||||
config ([`PretrainedConfig`]):
|
|
||||||
The config of the running model.
|
|
||||||
**kwargs:
|
|
||||||
The inputs of the model.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Two lists, one for the missing layers, and another one for the unexpected layers.
|
|
||||||
"""
|
|
||||||
signature = dict(inspect.signature(func).parameters)
|
|
||||||
signature.pop("kwargs", None)
|
|
||||||
signature.pop("self", None)
|
|
||||||
parameter_names = list(signature.keys())
|
|
||||||
output = {}
|
|
||||||
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
|
|
||||||
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
if isinstance(v, allowed_types) or v is None:
|
|
||||||
output[k] = v
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
|
||||||
|
|
||||||
if isinstance(input_values, (tuple, list)):
|
|
||||||
for i, input in enumerate(input_values):
|
|
||||||
# EagerTensors don't allow to use the .name property so we check for a real Tensor
|
|
||||||
if type(input) == tf.Tensor:
|
|
||||||
# Tensor names have always the pattern `name:id` then we check only the
|
|
||||||
# `name` part
|
|
||||||
tensor_name = input.name.split(":")[0]
|
|
||||||
|
|
||||||
if tensor_name in parameter_names:
|
|
||||||
output[tensor_name] = input
|
|
||||||
else:
|
|
||||||
output[parameter_names[i]] = input
|
|
||||||
elif isinstance(input, allowed_types) or input is None:
|
|
||||||
output[parameter_names[i]] = input
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
|
|
||||||
f" {parameter_names[i]}."
|
|
||||||
)
|
|
||||||
elif isinstance(input_values, Mapping):
|
|
||||||
if "inputs" in input_values:
|
|
||||||
warnings.warn(
|
|
||||||
"The `inputs` argument is deprecated and will be removed in a future version, use `input_values`"
|
|
||||||
" instead.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
output["input_values"] = input_values.pop("inputs")
|
|
||||||
|
|
||||||
if "decoder_cached_states" in input_values:
|
|
||||||
warnings.warn(
|
|
||||||
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
|
|
||||||
" `past_key_values` instead.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
output["past_key_values"] = input_values.pop("decoder_cached_states")
|
|
||||||
|
|
||||||
for k, v in dict(input_values).items():
|
|
||||||
if isinstance(v, allowed_types) or v is None:
|
|
||||||
output[k] = v
|
|
||||||
elif k not in parameter_names and "args" not in parameter_names:
|
|
||||||
logger.warning(
|
|
||||||
f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored."
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
|
||||||
else:
|
|
||||||
if isinstance(input_values, tf.Tensor) or input_values is None:
|
|
||||||
output[parameter_names[0]] = input_values
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Data of type {type(input_values)} is not allowed only {allowed_types} is accepted for"
|
|
||||||
f" {parameter_names[0]}."
|
|
||||||
)
|
|
||||||
|
|
||||||
for name in parameter_names:
|
|
||||||
if name not in list(output.keys()) and name != "args":
|
|
||||||
output[name] = kwargs.pop(name, signature[name].default)
|
|
||||||
|
|
||||||
# When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
|
|
||||||
# So to respect the proper output we have to add this exception
|
|
||||||
if "args" in output:
|
|
||||||
if output["args"] is not None and type(output["args"]) == tf.Tensor:
|
|
||||||
tensor_name = output["args"].name.split(":")[0]
|
|
||||||
output[tensor_name] = output["args"]
|
|
||||||
else:
|
|
||||||
# `args` in this case is always the first parameter, then `input_values`
|
|
||||||
output["input_values"] = output["args"]
|
|
||||||
|
|
||||||
del output["args"]
|
|
||||||
|
|
||||||
if "kwargs" in output:
|
|
||||||
del output["kwargs"]
|
|
||||||
|
|
||||||
boolean_dict = {
|
|
||||||
k: v
|
|
||||||
for k, v in output.items()
|
|
||||||
if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
|
|
||||||
}
|
|
||||||
|
|
||||||
output.update(booleans_processing(config=config, **boolean_dict))
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2._sample_without_replacement
|
# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2._sample_without_replacement
|
||||||
def _sample_without_replacement(distribution, num_samples):
|
def _sample_without_replacement(distribution, num_samples):
|
||||||
"""
|
"""
|
||||||
@ -1208,6 +1092,7 @@ class TFHubertMainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
@unpack_inputs
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_values: tf.Tensor,
|
input_values: tf.Tensor,
|
||||||
@ -1222,51 +1107,33 @@ class TFHubertMainLayer(tf.keras.layers.Layer):
|
|||||||
training: bool = False,
|
training: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
inputs = input_values_processing(
|
hidden_states = self.feature_extractor(tf.cast(input_values, tf.float32), training=training)
|
||||||
func=self.call,
|
|
||||||
config=self.config,
|
|
||||||
input_values=input_values,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
token_type_ids=token_type_ids,
|
|
||||||
position_ids=position_ids,
|
|
||||||
head_mask=head_mask,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
training=training,
|
|
||||||
kwargs_call=kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self.feature_extractor(
|
if attention_mask is not None:
|
||||||
tf.cast(inputs["input_values"], tf.float32), training=inputs["training"]
|
|
||||||
)
|
|
||||||
|
|
||||||
if inputs["attention_mask"] is not None:
|
|
||||||
# compute real output lengths according to convolution formula
|
# compute real output lengths according to convolution formula
|
||||||
output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1))
|
output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, -1))
|
||||||
|
|
||||||
attention_mask = tf.sequence_mask(
|
attention_mask = tf.sequence_mask(
|
||||||
output_lengths, maxlen=shape_list(hidden_states)[1], dtype=hidden_states.dtype
|
output_lengths, maxlen=shape_list(hidden_states)[1], dtype=hidden_states.dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = self.feature_projection(hidden_states, training=inputs["training"])
|
hidden_states = self.feature_projection(hidden_states, training=training)
|
||||||
|
|
||||||
mask_time_indices = kwargs.get("mask_time_indices", None)
|
mask_time_indices = kwargs.get("mask_time_indices", None)
|
||||||
if inputs["training"]:
|
if training:
|
||||||
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
|
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=inputs["return_dict"],
|
return_dict=return_dict,
|
||||||
training=inputs["training"],
|
training=training,
|
||||||
)
|
)
|
||||||
hidden_states = encoder_outputs[0]
|
hidden_states = encoder_outputs[0]
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not return_dict:
|
||||||
return (hidden_states,) + encoder_outputs[1:]
|
return (hidden_states,) + encoder_outputs[1:]
|
||||||
|
|
||||||
return TFBaseModelOutput(
|
return TFBaseModelOutput(
|
||||||
@ -1428,6 +1295,7 @@ class TFHubertModel(TFHubertPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
|
@unpack_inputs
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_values: tf.Tensor,
|
input_values: tf.Tensor,
|
||||||
@ -1469,9 +1337,11 @@ class TFHubertModel(TFHubertPreTrainedModel):
|
|||||||
>>> hidden_states = model(input_values).last_hidden_state
|
>>> hidden_states = model(input_values).last_hidden_state
|
||||||
```"""
|
```"""
|
||||||
|
|
||||||
inputs = input_values_processing(
|
output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
|
||||||
func=self.call,
|
output_attentions = output_attentions if output_attentions else self.config.output_attentions
|
||||||
config=self.config,
|
return_dict = return_dict if return_dict else self.config.return_dict
|
||||||
|
|
||||||
|
outputs = self.hubert(
|
||||||
input_values=input_values,
|
input_values=input_values,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
@ -1484,27 +1354,6 @@ class TFHubertModel(TFHubertPreTrainedModel):
|
|||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
inputs["output_hidden_states"] = (
|
|
||||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
inputs["output_attentions"] = (
|
|
||||||
inputs["output_attentions"] if inputs["output_attentions"] else self.config.output_attentions
|
|
||||||
)
|
|
||||||
inputs["return_dict"] = inputs["return_dict"] if inputs["return_dict"] else self.config.return_dict
|
|
||||||
|
|
||||||
outputs = self.hubert(
|
|
||||||
input_values=inputs["input_values"],
|
|
||||||
attention_mask=inputs["attention_mask"],
|
|
||||||
token_type_ids=inputs["token_type_ids"],
|
|
||||||
position_ids=inputs["position_ids"],
|
|
||||||
head_mask=inputs["head_mask"],
|
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
|
||||||
output_attentions=inputs["output_attentions"],
|
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
|
||||||
return_dict=inputs["return_dict"],
|
|
||||||
training=inputs["training"],
|
|
||||||
)
|
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def serving_output(self, output):
|
def serving_output(self, output):
|
||||||
@ -1548,6 +1397,7 @@ class TFHubertForCTC(TFHubertPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
|
@unpack_inputs
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_values: tf.Tensor,
|
input_values: tf.Tensor,
|
||||||
@ -1605,9 +1455,8 @@ class TFHubertForCTC(TFHubertPreTrainedModel):
|
|||||||
|
|
||||||
>>> loss = model(input_values, labels=labels).loss
|
>>> loss = model(input_values, labels=labels).loss
|
||||||
```"""
|
```"""
|
||||||
inputs = input_values_processing(
|
|
||||||
func=self.call,
|
outputs = self.hubert(
|
||||||
config=self.config,
|
|
||||||
input_values=input_values,
|
input_values=input_values,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
@ -1619,21 +1468,8 @@ class TFHubertForCTC(TFHubertPreTrainedModel):
|
|||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = self.hubert(
|
|
||||||
input_values=inputs["input_values"],
|
|
||||||
attention_mask=inputs["attention_mask"],
|
|
||||||
token_type_ids=inputs["token_type_ids"],
|
|
||||||
position_ids=inputs["position_ids"],
|
|
||||||
head_mask=inputs["head_mask"],
|
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
|
||||||
output_attentions=inputs["output_attentions"],
|
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
|
||||||
return_dict=inputs["return_dict"],
|
|
||||||
training=inputs["training"],
|
|
||||||
)
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
|
|
||||||
logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
@ -1642,9 +1478,7 @@ class TFHubertForCTC(TFHubertPreTrainedModel):
|
|||||||
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
||||||
|
|
||||||
attention_mask = (
|
attention_mask = (
|
||||||
inputs["attention_mask"]
|
attention_mask if attention_mask is not None else tf.ones_like(input_values, dtype=tf.float32)
|
||||||
if inputs["attention_mask"] is not None
|
|
||||||
else tf.ones_like(inputs["input_values"], dtype=tf.float32)
|
|
||||||
)
|
)
|
||||||
input_lengths = self.hubert._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, axis=-1))
|
input_lengths = self.hubert._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, axis=-1))
|
||||||
|
|
||||||
@ -1671,7 +1505,7 @@ class TFHubertForCTC(TFHubertPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
loss = None
|
loss = None
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
|
@ -14,9 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" TensorFlow Wav2Vec2 model."""
|
""" TensorFlow Wav2Vec2 model."""
|
||||||
|
|
||||||
import inspect
|
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Mapping
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
@ -27,7 +25,6 @@ from ...activations_tf import get_tf_activation
|
|||||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
|
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
|
||||||
from ...modeling_tf_utils import (
|
from ...modeling_tf_utils import (
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
booleans_processing,
|
|
||||||
get_initializer,
|
get_initializer,
|
||||||
keras_serializable,
|
keras_serializable,
|
||||||
unpack_inputs,
|
unpack_inputs,
|
||||||
@ -91,123 +88,6 @@ class TFWav2Vec2BaseModelOutput(ModelOutput):
|
|||||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
|
||||||
|
|
||||||
def input_values_processing(func, config, input_values, **kwargs):
|
|
||||||
"""
|
|
||||||
Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
|
|
||||||
has to be named accordingly to the parameters name, i.e. `input_values = tf.keras.Input(shape=(128,),
|
|
||||||
dtype='float32', name="input_values")` otherwise the order of the tensors will not be guaranteed during the
|
|
||||||
training.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
func (`callable`):
|
|
||||||
The callable function of the TensorFlow model.
|
|
||||||
config ([`PretrainedConfig`]):
|
|
||||||
The config of the running model.
|
|
||||||
**kwargs:
|
|
||||||
The inputs of the model.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Two lists, one for the missing layers, and another one for the unexpected layers.
|
|
||||||
"""
|
|
||||||
signature = dict(inspect.signature(func).parameters)
|
|
||||||
signature.pop("kwargs", None)
|
|
||||||
signature.pop("self", None)
|
|
||||||
parameter_names = list(signature.keys())
|
|
||||||
output = {}
|
|
||||||
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
|
|
||||||
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
if isinstance(v, allowed_types) or v is None:
|
|
||||||
output[k] = v
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
|
||||||
|
|
||||||
if isinstance(input_values, (tuple, list)):
|
|
||||||
for i, input in enumerate(input_values):
|
|
||||||
# EagerTensors don't allow to use the .name property so we check for a real Tensor
|
|
||||||
if type(input) == tf.Tensor:
|
|
||||||
# Tensor names have always the pattern `name:id` then we check only the
|
|
||||||
# `name` part
|
|
||||||
tensor_name = input.name.split(":")[0]
|
|
||||||
|
|
||||||
if tensor_name in parameter_names:
|
|
||||||
output[tensor_name] = input
|
|
||||||
else:
|
|
||||||
output[parameter_names[i]] = input
|
|
||||||
elif isinstance(input, allowed_types) or input is None:
|
|
||||||
output[parameter_names[i]] = input
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
|
|
||||||
f" {parameter_names[i]}."
|
|
||||||
)
|
|
||||||
elif isinstance(input_values, Mapping):
|
|
||||||
if "inputs" in input_values:
|
|
||||||
warnings.warn(
|
|
||||||
"The `inputs` argument is deprecated and will be removed in a future version, use `input_values`"
|
|
||||||
" instead.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
output["input_values"] = input_values.pop("inputs")
|
|
||||||
|
|
||||||
if "decoder_cached_states" in input_values:
|
|
||||||
warnings.warn(
|
|
||||||
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
|
|
||||||
" `past_key_values` instead.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
output["past_key_values"] = input_values.pop("decoder_cached_states")
|
|
||||||
|
|
||||||
for k, v in dict(input_values).items():
|
|
||||||
if isinstance(v, allowed_types) or v is None:
|
|
||||||
output[k] = v
|
|
||||||
elif k not in parameter_names and "args" not in parameter_names:
|
|
||||||
logger.warning(
|
|
||||||
f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored."
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
|
||||||
else:
|
|
||||||
if isinstance(input_values, tf.Tensor) or input_values is None:
|
|
||||||
output[parameter_names[0]] = input_values
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Data of type {type(input_values)} is not allowed only {allowed_types} is accepted for"
|
|
||||||
f" {parameter_names[0]}."
|
|
||||||
)
|
|
||||||
|
|
||||||
for name in parameter_names:
|
|
||||||
if name not in list(output.keys()) and name != "args":
|
|
||||||
output[name] = kwargs.pop(name, signature[name].default)
|
|
||||||
|
|
||||||
# When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
|
|
||||||
# So to respect the proper output we have to add this exception
|
|
||||||
if "args" in output:
|
|
||||||
if output["args"] is not None and type(output["args"]) == tf.Tensor:
|
|
||||||
tensor_name = output["args"].name.split(":")[0]
|
|
||||||
output[tensor_name] = output["args"]
|
|
||||||
else:
|
|
||||||
# `args` in this case is always the first parameter, then `input_values`
|
|
||||||
output["input_values"] = output["args"]
|
|
||||||
|
|
||||||
del output["args"]
|
|
||||||
|
|
||||||
if "kwargs" in output:
|
|
||||||
del output["kwargs"]
|
|
||||||
|
|
||||||
boolean_dict = {
|
|
||||||
k: v
|
|
||||||
for k, v in output.items()
|
|
||||||
if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
|
|
||||||
}
|
|
||||||
|
|
||||||
output.update(booleans_processing(config=config, **boolean_dict))
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def _sample_without_replacement(distribution, num_samples):
|
def _sample_without_replacement(distribution, num_samples):
|
||||||
"""
|
"""
|
||||||
Categorical sampling without replacement is currently not implemented. The gumbel-max trick will do for now - see
|
Categorical sampling without replacement is currently not implemented. The gumbel-max trick will do for now - see
|
||||||
@ -1238,6 +1118,7 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
@unpack_inputs
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_values: tf.Tensor,
|
input_values: tf.Tensor,
|
||||||
@ -1252,52 +1133,34 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
|
|||||||
training: bool = False,
|
training: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
inputs = input_values_processing(
|
extract_features = self.feature_extractor(tf.cast(input_values, tf.float32), training=training)
|
||||||
func=self.call,
|
|
||||||
config=self.config,
|
|
||||||
input_values=input_values,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
token_type_ids=token_type_ids,
|
|
||||||
position_ids=position_ids,
|
|
||||||
head_mask=head_mask,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
training=training,
|
|
||||||
kwargs_call=kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
extract_features = self.feature_extractor(
|
|
||||||
tf.cast(inputs["input_values"], tf.float32), training=inputs["training"]
|
|
||||||
)
|
|
||||||
# extract_features = tf.transpose(extract_features, perm=(0, 2, 1))
|
# extract_features = tf.transpose(extract_features, perm=(0, 2, 1))
|
||||||
|
|
||||||
if inputs["attention_mask"] is not None:
|
if attention_mask is not None:
|
||||||
# compute real output lengths according to convolution formula
|
# compute real output lengths according to convolution formula
|
||||||
output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1))
|
output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, -1))
|
||||||
|
|
||||||
attention_mask = tf.sequence_mask(
|
attention_mask = tf.sequence_mask(
|
||||||
output_lengths, maxlen=shape_list(extract_features)[1], dtype=extract_features.dtype
|
output_lengths, maxlen=shape_list(extract_features)[1], dtype=extract_features.dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, extract_features = self.feature_projection(extract_features, training=inputs["training"])
|
hidden_states, extract_features = self.feature_projection(extract_features, training=training)
|
||||||
|
|
||||||
mask_time_indices = kwargs.get("mask_time_indices", None)
|
mask_time_indices = kwargs.get("mask_time_indices", None)
|
||||||
if inputs["training"]:
|
if training:
|
||||||
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
|
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=inputs["return_dict"],
|
return_dict=return_dict,
|
||||||
training=inputs["training"],
|
training=training,
|
||||||
)
|
)
|
||||||
hidden_states = encoder_outputs[0]
|
hidden_states = encoder_outputs[0]
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not return_dict:
|
||||||
return (hidden_states, extract_features) + encoder_outputs[1:]
|
return (hidden_states, extract_features) + encoder_outputs[1:]
|
||||||
|
|
||||||
return TFWav2Vec2BaseModelOutput(
|
return TFWav2Vec2BaseModelOutput(
|
||||||
@ -1460,6 +1323,7 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
|
@unpack_inputs
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_values: tf.Tensor,
|
input_values: tf.Tensor,
|
||||||
@ -1501,9 +1365,11 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
|
|||||||
>>> hidden_states = model(input_values).last_hidden_state
|
>>> hidden_states = model(input_values).last_hidden_state
|
||||||
```"""
|
```"""
|
||||||
|
|
||||||
inputs = input_values_processing(
|
output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
|
||||||
func=self.call,
|
output_attentions = output_attentions if output_attentions else self.config.output_attentions
|
||||||
config=self.config,
|
return_dict = return_dict if return_dict else self.config.return_dict
|
||||||
|
|
||||||
|
outputs = self.wav2vec2(
|
||||||
input_values=input_values,
|
input_values=input_values,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
@ -1516,27 +1382,6 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
|
|||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
inputs["output_hidden_states"] = (
|
|
||||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
inputs["output_attentions"] = (
|
|
||||||
inputs["output_attentions"] if inputs["output_attentions"] else self.config.output_attentions
|
|
||||||
)
|
|
||||||
inputs["return_dict"] = inputs["return_dict"] if inputs["return_dict"] else self.config.return_dict
|
|
||||||
|
|
||||||
outputs = self.wav2vec2(
|
|
||||||
input_values=inputs["input_values"],
|
|
||||||
attention_mask=inputs["attention_mask"],
|
|
||||||
token_type_ids=inputs["token_type_ids"],
|
|
||||||
position_ids=inputs["position_ids"],
|
|
||||||
head_mask=inputs["head_mask"],
|
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
|
||||||
output_attentions=inputs["output_attentions"],
|
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
|
||||||
return_dict=inputs["return_dict"],
|
|
||||||
training=inputs["training"],
|
|
||||||
)
|
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def serving_output(self, output):
|
def serving_output(self, output):
|
||||||
@ -1642,9 +1487,8 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
|
|||||||
|
|
||||||
>>> loss = model(input_values, labels=labels).loss
|
>>> loss = model(input_values, labels=labels).loss
|
||||||
```"""
|
```"""
|
||||||
inputs = input_values_processing(
|
|
||||||
func=self.call,
|
outputs = self.wav2vec2(
|
||||||
config=self.config,
|
|
||||||
input_values=input_values,
|
input_values=input_values,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
@ -1656,21 +1500,8 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
|
|||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = self.wav2vec2(
|
|
||||||
input_values=inputs["input_values"],
|
|
||||||
attention_mask=inputs["attention_mask"],
|
|
||||||
token_type_ids=inputs["token_type_ids"],
|
|
||||||
position_ids=inputs["position_ids"],
|
|
||||||
head_mask=inputs["head_mask"],
|
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
|
||||||
output_attentions=inputs["output_attentions"],
|
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
|
||||||
return_dict=inputs["return_dict"],
|
|
||||||
training=inputs["training"],
|
|
||||||
)
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
|
|
||||||
logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
@ -1679,9 +1510,7 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
|
|||||||
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
||||||
|
|
||||||
attention_mask = (
|
attention_mask = (
|
||||||
inputs["attention_mask"]
|
attention_mask if attention_mask is not None else tf.ones_like(input_values, dtype=tf.float32)
|
||||||
if inputs["attention_mask"] is not None
|
|
||||||
else tf.ones_like(inputs["input_values"], dtype=tf.float32)
|
|
||||||
)
|
)
|
||||||
input_lengths = self.wav2vec2._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, axis=-1))
|
input_lengths = self.wav2vec2._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, axis=-1))
|
||||||
|
|
||||||
@ -1708,7 +1537,7 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
loss = None
|
loss = None
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
|
@ -304,18 +304,15 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||||
|
|
||||||
# Hubert has no inputs_embeds
|
@unittest.skip(reason="Hubert has no input embeddings")
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Hubert cannot resize token embeddings
|
@unittest.skip(reason="Hubert has no tokens embeddings")
|
||||||
# since it has no tokens embeddings
|
|
||||||
def test_resize_tokens_embeddings(self):
|
def test_resize_tokens_embeddings(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Hubert has no inputs_embeds
|
@unittest.skip(reason="Hubert has no input embeddings")
|
||||||
# and thus the `get_input_embeddings` fn
|
|
||||||
# is not implemented
|
|
||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -324,10 +321,6 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
model = TFHubertModel.from_pretrained("facebook/hubert-base-ls960")
|
model = TFHubertModel.from_pretrained("facebook/hubert-base-ls960")
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
@unittest.skip("Loss shapes for CTC don't match the base test.")
|
|
||||||
def test_loss_computation(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
@ -426,29 +419,36 @@ class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||||
|
|
||||||
# Hubert has no inputs_embeds
|
@unittest.skip(reason="Hubert has no input embeddings")
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Hubert cannot resize token embeddings
|
@unittest.skip(reason="Hubert has no tokens embeddings")
|
||||||
# since it has no tokens embeddings
|
|
||||||
def test_resize_tokens_embeddings(self):
|
def test_resize_tokens_embeddings(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Hubert has no inputs_embeds
|
@unittest.skip(reason="Hubert has no input embeddings or get_input_embeddings method")
|
||||||
# and thus the `get_input_embeddings` fn
|
|
||||||
# is not implemented
|
|
||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||||
|
def test_dataset_conversion(self):
|
||||||
|
default_batch_size = self.model_tester.batch_size
|
||||||
|
self.model_tester.batch_size = 2
|
||||||
|
super().test_dataset_conversion()
|
||||||
|
self.model_tester.batch_size = default_batch_size
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
|
model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
@unittest.skip("Loss shapes for CTC don't match the base test.")
|
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||||
def test_loss_computation(self):
|
def test_keras_fit(self):
|
||||||
pass
|
default_batch_size = self.model_tester.batch_size
|
||||||
|
self.model_tester.batch_size = 2
|
||||||
|
super().test_keras_fit()
|
||||||
|
self.model_tester.batch_size = default_batch_size
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
|
@ -369,18 +369,15 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_training(*config_and_inputs)
|
self.model_tester.check_training(*config_and_inputs)
|
||||||
|
|
||||||
# Wav2Vec2 has no inputs_embeds
|
@unittest.skip(reason="Wav2Vec2 has no input embeddings")
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Wav2Vec2 cannot resize token embeddings
|
@unittest.skip(reason="Wav2Vec2 has no tokens embeddings")
|
||||||
# since it has no tokens embeddings
|
|
||||||
def test_resize_tokens_embeddings(self):
|
def test_resize_tokens_embeddings(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Wav2Vec2 has no inputs_embeds
|
@unittest.skip(reason="Wav2Vec2 has no input embeddings")
|
||||||
# and thus the `get_input_embeddings` fn
|
|
||||||
# is not implemented
|
|
||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -389,13 +386,19 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
@unittest.skip(reason="Dataset conversion goes OOM and crashes with the default options!")
|
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||||
def test_dataset_conversion(self):
|
def test_dataset_conversion(self):
|
||||||
pass
|
default_batch_size = self.model_tester.batch_size
|
||||||
|
self.model_tester.batch_size = 2
|
||||||
|
super().test_dataset_conversion()
|
||||||
|
self.model_tester.batch_size = default_batch_size
|
||||||
|
|
||||||
@unittest.skip(reason="Training goes OOM and crashes with the default options!")
|
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||||
def test_keras_fit(self):
|
def test_keras_fit(self):
|
||||||
pass
|
default_batch_size = self.model_tester.batch_size
|
||||||
|
self.model_tester.batch_size = 2
|
||||||
|
super().test_dataset_conversion()
|
||||||
|
self.model_tester.batch_size = default_batch_size
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@ -497,18 +500,15 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_training(*config_and_inputs)
|
self.model_tester.check_training(*config_and_inputs)
|
||||||
|
|
||||||
# Wav2Vec2 has no inputs_embeds
|
@unittest.skip(reason="Wav2Vec2 has no input embeddings")
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Wav2Vec2 cannot resize token embeddings
|
@unittest.skip(reason="Wav2Vec2 has no tokens embeddings")
|
||||||
# since it has no tokens embeddings
|
|
||||||
def test_resize_tokens_embeddings(self):
|
def test_resize_tokens_embeddings(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Wav2Vec2 has no inputs_embeds
|
@unittest.skip(reason="Wav2Vec2 has no input embeddings")
|
||||||
# and thus the `get_input_embeddings` fn
|
|
||||||
# is not implemented
|
|
||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -517,13 +517,19 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
@unittest.skip(reason="Dataset conversion goes OOM and crashes with the default options!")
|
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||||
def test_dataset_conversion(self):
|
def test_dataset_conversion(self):
|
||||||
pass
|
default_batch_size = self.model_tester.batch_size
|
||||||
|
self.model_tester.batch_size = 2
|
||||||
|
super().test_dataset_conversion()
|
||||||
|
self.model_tester.batch_size = default_batch_size
|
||||||
|
|
||||||
@unittest.skip(reason="Training goes OOM and crashes with the default options!")
|
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
|
||||||
def test_keras_fit(self):
|
def test_keras_fit(self):
|
||||||
pass
|
default_batch_size = self.model_tester.batch_size
|
||||||
|
self.model_tester.batch_size = 2
|
||||||
|
super().test_dataset_conversion()
|
||||||
|
self.model_tester.batch_size = default_batch_size
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
|
Loading…
Reference in New Issue
Block a user