diff --git a/src/transformers/models/hubert/modeling_tf_hubert.py b/src/transformers/models/hubert/modeling_tf_hubert.py index f4722532e8a..24cbde9af7c 100644 --- a/src/transformers/models/hubert/modeling_tf_hubert.py +++ b/src/transformers/models/hubert/modeling_tf_hubert.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ TensorFlow Hubert model.""" -import inspect import warnings -from collections.abc import Mapping from typing import Any, Dict, Optional, Tuple, Union import numpy as np @@ -23,10 +21,14 @@ import tensorflow as tf from ...activations_tf import get_tf_activation from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput -from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable +from ...modeling_tf_utils import ( + TFPreTrainedModel, + get_initializer, + keras_serializable, + unpack_inputs, +) from ...tf_utils import shape_list, stable_softmax from ...utils import ( - ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging, @@ -47,124 +49,6 @@ TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ LARGE_NEGATIVE = -1e8 -# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.input_values_processing -def input_values_processing(func, config, input_values, **kwargs): - """ - Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input - has to be named accordingly to the parameters name, i.e. `input_values = tf.keras.Input(shape=(128,), - dtype='float32', name="input_values")` otherwise the order of the tensors will not be guaranteed during the - training. - - Args: - func (`callable`): - The callable function of the TensorFlow model. - config ([`PretrainedConfig`]): - The config of the running model. - **kwargs: - The inputs of the model. - - Returns: - Two lists, one for the missing layers, and another one for the unexpected layers. - """ - signature = dict(inspect.signature(func).parameters) - signature.pop("kwargs", None) - signature.pop("self", None) - parameter_names = list(signature.keys()) - output = {} - allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray) - - for k, v in kwargs.items(): - if isinstance(v, allowed_types) or v is None: - output[k] = v - else: - raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.") - - if isinstance(input_values, (tuple, list)): - for i, input in enumerate(input_values): - # EagerTensors don't allow to use the .name property so we check for a real Tensor - if type(input) == tf.Tensor: - # Tensor names have always the pattern `name:id` then we check only the - # `name` part - tensor_name = input.name.split(":")[0] - - if tensor_name in parameter_names: - output[tensor_name] = input - else: - output[parameter_names[i]] = input - elif isinstance(input, allowed_types) or input is None: - output[parameter_names[i]] = input - else: - raise ValueError( - f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for" - f" {parameter_names[i]}." 
- ) - elif isinstance(input_values, Mapping): - if "inputs" in input_values: - warnings.warn( - "The `inputs` argument is deprecated and will be removed in a future version, use `input_values`" - " instead.", - FutureWarning, - ) - - output["input_values"] = input_values.pop("inputs") - - if "decoder_cached_states" in input_values: - warnings.warn( - "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use" - " `past_key_values` instead.", - FutureWarning, - ) - output["past_key_values"] = input_values.pop("decoder_cached_states") - - for k, v in dict(input_values).items(): - if isinstance(v, allowed_types) or v is None: - output[k] = v - elif k not in parameter_names and "args" not in parameter_names: - logger.warning( - f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored." - ) - continue - else: - raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.") - else: - if isinstance(input_values, tf.Tensor) or input_values is None: - output[parameter_names[0]] = input_values - else: - raise ValueError( - f"Data of type {type(input_values)} is not allowed only {allowed_types} is accepted for" - f" {parameter_names[0]}." - ) - - for name in parameter_names: - if name not in list(output.keys()) and name != "args": - output[name] = kwargs.pop(name, signature[name].default) - - # When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs) - # So to respect the proper output we have to add this exception - if "args" in output: - if output["args"] is not None and type(output["args"]) == tf.Tensor: - tensor_name = output["args"].name.split(":")[0] - output[tensor_name] = output["args"] - else: - # `args` in this case is always the first parameter, then `input_values` - output["input_values"] = output["args"] - - del output["args"] - - if "kwargs" in output: - del output["kwargs"] - - boolean_dict = { - k: v - for k, v in output.items() - if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"] - } - - output.update(booleans_processing(config=config, **boolean_dict)) - - return output - - # Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2._sample_without_replacement def _sample_without_replacement(distribution, num_samples): """ @@ -1208,6 +1092,7 @@ class TFHubertMainLayer(tf.keras.layers.Layer): return hidden_states + @unpack_inputs def call( self, input_values: tf.Tensor, @@ -1222,51 +1107,33 @@ class TFHubertMainLayer(tf.keras.layers.Layer): training: bool = False, **kwargs: Any, ): - inputs = input_values_processing( - func=self.call, - config=self.config, - input_values=input_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) + hidden_states = self.feature_extractor(tf.cast(input_values, tf.float32), training=training) - hidden_states = self.feature_extractor( - tf.cast(inputs["input_values"], tf.float32), training=inputs["training"] - ) - - if inputs["attention_mask"] is not None: + if attention_mask is not None: # compute real output lengths according to convolution formula - output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1)) + output_lengths = 
self._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, -1)) attention_mask = tf.sequence_mask( output_lengths, maxlen=shape_list(hidden_states)[1], dtype=hidden_states.dtype ) - hidden_states = self.feature_projection(hidden_states, training=inputs["training"]) + hidden_states = self.feature_projection(hidden_states, training=training) mask_time_indices = kwargs.get("mask_time_indices", None) - if inputs["training"]: + if training: hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices) encoder_outputs = self.encoder( hidden_states, attention_mask=attention_mask, - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, ) hidden_states = encoder_outputs[0] - if not inputs["return_dict"]: + if not return_dict: return (hidden_states,) + encoder_outputs[1:] return TFBaseModelOutput( @@ -1428,6 +1295,7 @@ class TFHubertModel(TFHubertPreTrainedModel): @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC) + @unpack_inputs def call( self, input_values: tf.Tensor, @@ -1469,9 +1337,11 @@ class TFHubertModel(TFHubertPreTrainedModel): >>> hidden_states = model(input_values).last_hidden_state ```""" - inputs = input_values_processing( - func=self.call, - config=self.config, + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + return_dict = return_dict if return_dict is not None else self.config.return_dict + + outputs = self.hubert( input_values=input_values, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, ) - inputs["output_hidden_states"] = ( - inputs["output_hidden_states"] if inputs["output_hidden_states"] else self.config.output_hidden_states - ) - inputs["output_attentions"] = ( - inputs["output_attentions"] if inputs["output_attentions"] else self.config.output_attentions - ) - inputs["return_dict"] = inputs["return_dict"] if inputs["return_dict"] else self.config.return_dict - - outputs = self.hubert( - input_values=inputs["input_values"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], - ) - return outputs def serving_output(self, output): @@ -1548,6 +1397,7 @@ class TFHubertForCTC(TFHubertPreTrainedModel): @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC) + @unpack_inputs def call( self, input_values: tf.Tensor, @@ -1605,9 +1455,8 @@ class TFHubertForCTC(TFHubertPreTrainedModel): >>> loss = model(input_values, labels=labels).loss ```""" - inputs = input_values_processing( - func=self.call, - config=self.config, + + outputs = self.hubert( input_values=input_values, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1619,21 +1468,8 @@ 
class TFHubertForCTC(TFHubertPreTrainedModel): return_dict=return_dict, training=training, ) - - outputs = self.hubert( - input_values=inputs["input_values"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], - ) hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states, training=inputs["training"]) + hidden_states = self.dropout(hidden_states, training=training) logits = self.lm_head(hidden_states) @@ -1642,9 +1478,7 @@ class TFHubertForCTC(TFHubertPreTrainedModel): raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") attention_mask = ( - inputs["attention_mask"] - if inputs["attention_mask"] is not None - else tf.ones_like(inputs["input_values"], dtype=tf.float32) + attention_mask if attention_mask is not None else tf.ones_like(input_values, dtype=tf.float32) ) input_lengths = self.hubert._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, axis=-1)) @@ -1671,7 +1505,7 @@ class TFHubertForCTC(TFHubertPreTrainedModel): else: loss = None - if not inputs["return_dict"]: + if not return_dict: output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output diff --git a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py index a3f2fd0e1e1..64defa33597 100644 --- a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py @@ -14,9 +14,7 @@ # limitations under the License. """ TensorFlow Wav2Vec2 model.""" -import inspect import warnings -from collections.abc import Mapping from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union @@ -27,7 +25,6 @@ from ...activations_tf import get_tf_activation from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput from ...modeling_tf_utils import ( TFPreTrainedModel, - booleans_processing, get_initializer, keras_serializable, unpack_inputs, @@ -91,123 +88,6 @@ class TFWav2Vec2BaseModelOutput(ModelOutput): attentions: Optional[Tuple[tf.Tensor]] = None -def input_values_processing(func, config, input_values, **kwargs): - """ - Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input - has to be named accordingly to the parameters name, i.e. `input_values = tf.keras.Input(shape=(128,), - dtype='float32', name="input_values")` otherwise the order of the tensors will not be guaranteed during the - training. - - Args: - func (`callable`): - The callable function of the TensorFlow model. - config ([`PretrainedConfig`]): - The config of the running model. - **kwargs: - The inputs of the model. - - Returns: - Two lists, one for the missing layers, and another one for the unexpected layers. 
- """ - signature = dict(inspect.signature(func).parameters) - signature.pop("kwargs", None) - signature.pop("self", None) - parameter_names = list(signature.keys()) - output = {} - allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray) - - for k, v in kwargs.items(): - if isinstance(v, allowed_types) or v is None: - output[k] = v - else: - raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.") - - if isinstance(input_values, (tuple, list)): - for i, input in enumerate(input_values): - # EagerTensors don't allow to use the .name property so we check for a real Tensor - if type(input) == tf.Tensor: - # Tensor names have always the pattern `name:id` then we check only the - # `name` part - tensor_name = input.name.split(":")[0] - - if tensor_name in parameter_names: - output[tensor_name] = input - else: - output[parameter_names[i]] = input - elif isinstance(input, allowed_types) or input is None: - output[parameter_names[i]] = input - else: - raise ValueError( - f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for" - f" {parameter_names[i]}." - ) - elif isinstance(input_values, Mapping): - if "inputs" in input_values: - warnings.warn( - "The `inputs` argument is deprecated and will be removed in a future version, use `input_values`" - " instead.", - FutureWarning, - ) - - output["input_values"] = input_values.pop("inputs") - - if "decoder_cached_states" in input_values: - warnings.warn( - "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use" - " `past_key_values` instead.", - FutureWarning, - ) - output["past_key_values"] = input_values.pop("decoder_cached_states") - - for k, v in dict(input_values).items(): - if isinstance(v, allowed_types) or v is None: - output[k] = v - elif k not in parameter_names and "args" not in parameter_names: - logger.warning( - f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored." - ) - continue - else: - raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.") - else: - if isinstance(input_values, tf.Tensor) or input_values is None: - output[parameter_names[0]] = input_values - else: - raise ValueError( - f"Data of type {type(input_values)} is not allowed only {allowed_types} is accepted for" - f" {parameter_names[0]}." - ) - - for name in parameter_names: - if name not in list(output.keys()) and name != "args": - output[name] = kwargs.pop(name, signature[name].default) - - # When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs) - # So to respect the proper output we have to add this exception - if "args" in output: - if output["args"] is not None and type(output["args"]) == tf.Tensor: - tensor_name = output["args"].name.split(":")[0] - output[tensor_name] = output["args"] - else: - # `args` in this case is always the first parameter, then `input_values` - output["input_values"] = output["args"] - - del output["args"] - - if "kwargs" in output: - del output["kwargs"] - - boolean_dict = { - k: v - for k, v in output.items() - if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"] - } - - output.update(booleans_processing(config=config, **boolean_dict)) - - return output - - def _sample_without_replacement(distribution, num_samples): """ Categorical sampling without replacement is currently not implemented. 
The gumbel-max trick will do for now - see @@ -1238,6 +1118,7 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer): return hidden_states + @unpack_inputs def call( self, input_values: tf.Tensor, @@ -1252,52 +1133,34 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer): training: bool = False, **kwargs: Any, ): - inputs = input_values_processing( - func=self.call, - config=self.config, - input_values=input_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) - - extract_features = self.feature_extractor( - tf.cast(inputs["input_values"], tf.float32), training=inputs["training"] - ) + extract_features = self.feature_extractor(tf.cast(input_values, tf.float32), training=training) # extract_features = tf.transpose(extract_features, perm=(0, 2, 1)) - if inputs["attention_mask"] is not None: + if attention_mask is not None: # compute real output lengths according to convolution formula - output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1)) + output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, -1)) attention_mask = tf.sequence_mask( output_lengths, maxlen=shape_list(extract_features)[1], dtype=extract_features.dtype ) - hidden_states, extract_features = self.feature_projection(extract_features, training=inputs["training"]) + hidden_states, extract_features = self.feature_projection(extract_features, training=training) mask_time_indices = kwargs.get("mask_time_indices", None) - if inputs["training"]: + if training: hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices) encoder_outputs = self.encoder( hidden_states, attention_mask=attention_mask, - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, ) hidden_states = encoder_outputs[0] - if not inputs["return_dict"]: + if not return_dict: return (hidden_states, extract_features) + encoder_outputs[1:] return TFWav2Vec2BaseModelOutput( @@ -1460,6 +1323,7 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel): @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC) + @unpack_inputs def call( self, input_values: tf.Tensor, @@ -1501,9 +1365,11 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel): >>> hidden_states = model(input_values).last_hidden_state ```""" - inputs = input_values_processing( - func=self.call, - config=self.config, + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + return_dict = return_dict if return_dict is not None else self.config.return_dict + + outputs = self.wav2vec2( input_values=input_values, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1516,27 +1382,6 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel): training=training, ) - inputs["output_hidden_states"] = ( - inputs["output_hidden_states"] if inputs["output_hidden_states"] else 
self.config.output_hidden_states - ) - inputs["output_attentions"] = ( - inputs["output_attentions"] if inputs["output_attentions"] else self.config.output_attentions - ) - inputs["return_dict"] = inputs["return_dict"] if inputs["return_dict"] else self.config.return_dict - - outputs = self.wav2vec2( - input_values=inputs["input_values"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], - ) - return outputs def serving_output(self, output): @@ -1642,9 +1487,8 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel): >>> loss = model(input_values, labels=labels).loss ```""" - inputs = input_values_processing( - func=self.call, - config=self.config, + + outputs = self.wav2vec2( input_values=input_values, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1656,21 +1500,8 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel): return_dict=return_dict, training=training, ) - - outputs = self.wav2vec2( - input_values=inputs["input_values"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], - ) hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states, training=inputs["training"]) + hidden_states = self.dropout(hidden_states, training=training) logits = self.lm_head(hidden_states) @@ -1679,9 +1510,7 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel): raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") attention_mask = ( - inputs["attention_mask"] - if inputs["attention_mask"] is not None - else tf.ones_like(inputs["input_values"], dtype=tf.float32) + attention_mask if attention_mask is not None else tf.ones_like(input_values, dtype=tf.float32) ) input_lengths = self.wav2vec2._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, axis=-1)) @@ -1708,7 +1537,7 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel): else: loss = None - if not inputs["return_dict"]: + if not return_dict: output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] return ((loss,) + output) if loss is not None else output diff --git a/tests/models/hubert/test_modeling_tf_hubert.py b/tests/models/hubert/test_modeling_tf_hubert.py index 084a4001100..5b2183b2dfe 100644 --- a/tests/models/hubert/test_modeling_tf_hubert.py +++ b/tests/models/hubert/test_modeling_tf_hubert.py @@ -304,18 +304,15 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_labels_out_of_vocab(*config_and_inputs) - # Hubert has no inputs_embeds + @unittest.skip(reason="Hubert has no input embeddings") def test_inputs_embeds(self): pass - # Hubert cannot resize token embeddings - # since it has no tokens embeddings + @unittest.skip(reason="Hubert has no token embeddings") def test_resize_tokens_embeddings(self): pass - # Hubert has no inputs_embeds - # and thus the `get_input_embeddings` fn - # is not implemented + @unittest.skip(reason="Hubert has no input 
embeddings") def test_model_common_attributes(self): pass @@ -324,10 +321,6 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase): model = TFHubertModel.from_pretrained("facebook/hubert-base-ls960") self.assertIsNotNone(model) - @unittest.skip("Loss shapes for CTC don't match the base test.") - def test_loss_computation(self): - pass - @require_tf class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase): @@ -426,29 +419,36 @@ class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_labels_out_of_vocab(*config_and_inputs) - # Hubert has no inputs_embeds + @unittest.skip(reason="Hubert has no input embeddings") def test_inputs_embeds(self): pass - # Hubert cannot resize token embeddings - # since it has no tokens embeddings + @unittest.skip(reason="Hubert has no token embeddings") def test_resize_tokens_embeddings(self): pass - # Hubert has no inputs_embeds - # and thus the `get_input_embeddings` fn - # is not implemented + @unittest.skip(reason="Hubert has no input embeddings or get_input_embeddings method") def test_model_common_attributes(self): pass + # We override here as passing a full batch of 13 samples results in OOM errors for CTC + def test_dataset_conversion(self): + default_batch_size = self.model_tester.batch_size + self.model_tester.batch_size = 2 + super().test_dataset_conversion() + self.model_tester.batch_size = default_batch_size + @slow def test_model_from_pretrained(self): model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft") self.assertIsNotNone(model) - @unittest.skip("Loss shapes for CTC don't match the base test.") - def test_loss_computation(self): - pass + # We override here as passing a full batch of 13 samples results in OOM errors for CTC + def test_keras_fit(self): + default_batch_size = self.model_tester.batch_size + self.model_tester.batch_size = 2 + super().test_keras_fit() + self.model_tester.batch_size = default_batch_size @require_tf diff --git a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py index 38e83bcdf9e..42946fce496 100644 --- a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py @@ -369,18 +369,15 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_training(*config_and_inputs) - # Wav2Vec2 has no inputs_embeds + @unittest.skip(reason="Wav2Vec2 has no input embeddings") def test_inputs_embeds(self): pass - # Wav2Vec2 cannot resize token embeddings - # since it has no tokens embeddings + @unittest.skip(reason="Wav2Vec2 has no token embeddings") def test_resize_tokens_embeddings(self): pass - # Wav2Vec2 has no inputs_embeds - # and thus the `get_input_embeddings` fn - # is not implemented + @unittest.skip(reason="Wav2Vec2 has no input embeddings") def test_model_common_attributes(self): pass @@ -389,13 +386,19 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase): model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") self.assertIsNotNone(model) - @unittest.skip(reason="Dataset conversion goes OOM and crashes with the default options!") + # We override here as passing a full batch of 13 samples results in OOM errors for CTC def test_dataset_conversion(self): - pass + default_batch_size = self.model_tester.batch_size + self.model_tester.batch_size = 2 + 
super().test_dataset_conversion() + self.model_tester.batch_size = default_batch_size - @unittest.skip(reason="Training goes OOM and crashes with the default options!") + # We override here as passing a full batch of 13 samples results in OOM errors for CTC def test_keras_fit(self): - pass + default_batch_size = self.model_tester.batch_size + self.model_tester.batch_size = 2 + super().test_keras_fit() + self.model_tester.batch_size = default_batch_size @@ -497,18 +500,15 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_training(*config_and_inputs) - # Wav2Vec2 has no inputs_embeds + @unittest.skip(reason="Wav2Vec2 has no input embeddings") def test_inputs_embeds(self): pass - # Wav2Vec2 cannot resize token embeddings - # since it has no tokens embeddings + @unittest.skip(reason="Wav2Vec2 has no token embeddings") def test_resize_tokens_embeddings(self): pass - # Wav2Vec2 has no inputs_embeds - # and thus the `get_input_embeddings` fn - # is not implemented + @unittest.skip(reason="Wav2Vec2 has no input embeddings") def test_model_common_attributes(self): pass @@ -517,13 +517,19 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase): model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") self.assertIsNotNone(model) - @unittest.skip(reason="Dataset conversion goes OOM and crashes with the default options!") + # We override here as passing a full batch of 13 samples results in OOM errors for CTC def test_dataset_conversion(self): - pass + default_batch_size = self.model_tester.batch_size + self.model_tester.batch_size = 2 + super().test_dataset_conversion() + self.model_tester.batch_size = default_batch_size - @unittest.skip(reason="Training goes OOM and crashes with the default options!") + # We override here as passing a full batch of 13 samples results in OOM errors for CTC def test_keras_fit(self): - pass + default_batch_size = self.model_tester.batch_size + self.model_tester.batch_size = 2 + super().test_keras_fit() + self.model_tester.batch_size = default_batch_size @require_tf
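Reviewer note on `unpack_inputs`: the decorator (imported from `modeling_tf_utils`) takes over the work of the deleted `input_values_processing` helper. Roughly, it binds the arguments of `call` against its signature, resolves boolean flags that were left as `None` from the model config, and then invokes `call` with plain keyword arguments, which is why the method bodies above can read `attention_mask`, `return_dict`, etc. directly instead of going through an `inputs` dict. The sketch below is only an illustration of that pattern; the names `unpack_inputs_sketch` and `config_booleans` are made up for this note, and the real implementation in `modeling_tf_utils.py` additionally handles SavedModel tracing and tensor-name matching that this sketch omits.

```python
import functools
import inspect


def unpack_inputs_sketch(func):
    """Illustrative stand-in for `unpack_inputs` (simplified, not the real helper)."""
    # Boolean flags whose defaults come from the model config when left as None.
    config_booleans = ("output_attentions", "output_hidden_states", "return_dict", "use_cache")

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Bind positional and keyword arguments against the wrapped signature so
        # the body always receives named parameters, whatever the call style was.
        bound = inspect.signature(func).bind(self, *args, **kwargs)
        bound.apply_defaults()
        unpacked = dict(bound.arguments)
        unpacked.pop("self")
        extra = unpacked.pop("kwargs", {})  # forward any **kwargs untouched
        # Fill booleans left as None from the config, mirroring the
        # `booleans_processing` step of the removed helper.
        for name in config_booleans:
            if name in unpacked and unpacked[name] is None:
                unpacked[name] = getattr(self.config, name, None)
        return func(self, **unpacked, **extra)

    return wrapper
```

Under these assumptions, the explicit `output_hidden_states`/`output_attentions`/`return_dict` fallbacks kept in `TFHubertModel.call` and `TFWav2Vec2Model.call` act only as a safety net: by the time those bodies run, the decorator has already replaced `None` with the config defaults.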