From 312b104ff65514736c0475814fec19e47425b0b5 Mon Sep 17 00:00:00 2001 From: raghavanone <115454562+raghavanone@users.noreply.github.com> Date: Fri, 5 May 2023 22:53:46 +0530 Subject: [PATCH] Add FlaxWhisperForAudioClassification model (#23173) * Add FlaxWhisperForAudioClassification model * Add models to init * Add models to init * Fix copies * Fix automapping * Fix failing test --- docs/source/en/model_doc/whisper.mdx | 6 + src/transformers/__init__.py | 8 +- .../models/auto/modeling_flax_auto.py | 5 + src/transformers/models/whisper/__init__.py | 2 + .../models/whisper/modeling_flax_whisper.py | 161 ++++++++++++++ src/transformers/utils/dummy_flax_objects.py | 7 + .../whisper/test_modeling_flax_whisper.py | 205 +++++++++++++++++- tests/models/whisper/test_modeling_whisper.py | 8 +- 8 files changed, 395 insertions(+), 7 deletions(-) diff --git a/docs/source/en/model_doc/whisper.mdx b/docs/source/en/model_doc/whisper.mdx index 22b08e4e61b..52a8b5953c6 100644 --- a/docs/source/en/model_doc/whisper.mdx +++ b/docs/source/en/model_doc/whisper.mdx @@ -105,3 +105,9 @@ The original code can be found [here](https://github.com/openai/whisper). [[autodoc]] FlaxWhisperForConditionalGeneration - __call__ + +## FlaxWhisperForAudioClassification + +[[autodoc]] FlaxWhisperForAudioClassification + - __call__ + diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 7bf322ca8e1..b0766b0946c 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3779,6 +3779,7 @@ else: "FlaxWhisperForConditionalGeneration", "FlaxWhisperModel", "FlaxWhisperPreTrainedModel", + "FlaxWhisperForAudioClassification", ] ) _import_structure["models.xglm"].extend( @@ -6903,7 +6904,12 @@ if TYPE_CHECKING: FlaxWav2Vec2Model, FlaxWav2Vec2PreTrainedModel, ) - from .models.whisper import FlaxWhisperForConditionalGeneration, FlaxWhisperModel, FlaxWhisperPreTrainedModel + from .models.whisper import ( + FlaxWhisperForAudioClassification, + FlaxWhisperForConditionalGeneration, + FlaxWhisperModel, + FlaxWhisperPreTrainedModel, + ) from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel from .models.xlm_roberta import ( FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 755d1f07a34..e3b8d9cf5b5 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -229,6 +229,11 @@ FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( ] ) +FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + ("whisper", "FlaxWhisperForAudioClassification"), + ] +) FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES) FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES) diff --git a/src/transformers/models/whisper/__init__.py b/src/transformers/models/whisper/__init__.py index 3b6015a56f6..cd962478e34 100644 --- a/src/transformers/models/whisper/__init__.py +++ b/src/transformers/models/whisper/__init__.py @@ -75,6 +75,7 @@ else: "FlaxWhisperForConditionalGeneration", "FlaxWhisperModel", "FlaxWhisperPreTrainedModel", + "FlaxWhisperForAudioClassification", ] @@ -126,6 +127,7 @@ if TYPE_CHECKING: pass else: from .modeling_flax_whisper import ( + FlaxWhisperForAudioClassification, FlaxWhisperForConditionalGeneration, FlaxWhisperModel, FlaxWhisperPreTrainedModel, diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index b8d6f07242d..1a994acea4d 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -36,6 +36,7 @@ from ...modeling_flax_outputs import ( FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput, FlaxSeq2SeqModelOutput, + FlaxSequenceClassifierOutput, ) from ...modeling_flax_utils import ( ACT2FN, @@ -1506,3 +1507,163 @@ overwrite_call_docstring( append_replace_return_docstrings( FlaxWhisperForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC ) + + +class FlaxWhisperForAudioClassificationModule(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self) -> None: + self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype) + self.config.is_encoder_decoder = False + num_layers = self.config.num_hidden_layers + 1 + if self.config.use_weighted_layer_sum: + self.layer_weights = jnp.repeat(1 / num_layers, num_layers) + self.projector = nn.Dense(self.config.classifier_proj_size, dtype=self.dtype) + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_features, + encoder_outputs=None, + output_attentions=None, + output_hidden_states: bool = True, + return_dict: bool = True, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_features, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = jnp.stack(encoder_outputs, axis=1) + norm_weights = jax.nn.softmax(self.layer_weights, axis=-1) + hidden_states = jnp.sum(hidden_states * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1) + else: + hidden_states = encoder_outputs[0] + + hidden_states = self.projector(hidden_states) + pooled_output = jnp.mean(hidden_states, axis=1) + + logits = self.classifier(pooled_output) + + if not return_dict: + return (logits,) + encoder_outputs[1:] + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings("The Whisper Model with an audio classification head on top.", WHISPER_START_DOCSTRING) +class FlaxWhisperForAudioClassification(FlaxWhisperPreTrainedModel): + module_class = FlaxWhisperForAudioClassificationModule + dtype: jnp.dtype = jnp.float32 + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_features = jnp.zeros(input_shape, dtype="f4") + input_features = input_features.at[(..., -1)].set(self.config.eos_token_id) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + input_features=input_features, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) + def __call__( + self, + input_features: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + **kwargs, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + input_features=jnp.array(input_features, dtype="f4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + ) + + +FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING = r""" + Returns: + + Transcription example: + + ```python + >>> import jax.numpy as jnp + >>> from transformers import AutoFeatureExtractor, FlaxWhisperForAudioClassification + >>> from datasets import load_dataset + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") + >>> model = FlaxWhisperForAudioClassification.from_pretrained( + ... "sanchit-gandhi/whisper-medium-fleurs-lang-id", from_pt=True + ... ) + >>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True) + + >>> sample = next(iter(ds)) + + >>> inputs = feature_extractor( + ... sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="np" + ... ) + >>> input_features = inputs.input_features + + >>> logits = model(input_features).logits + + >>> predicted_class_ids = jnp.argmax(logits).item() + >>> predicted_label = model.config.id2label[predicted_class_ids] + >>> predicted_label + 'af_za' + ``` +""" + +overwrite_call_docstring( + FlaxWhisperForAudioClassification, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING +) +append_replace_return_docstrings( + FlaxWhisperForAudioClassification, output_type=FlaxSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC +) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index eeec3277492..ce571bc9f8d 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -1182,6 +1182,13 @@ class FlaxWav2Vec2PreTrainedModel(metaclass=DummyObject): requires_backends(self, ["flax"]) +class FlaxWhisperForAudioClassification(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxWhisperForConditionalGeneration(metaclass=DummyObject): _backends = ["flax"] diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index 3f1e201d72d..79a2c51039a 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - import functools import inspect import tempfile @@ -41,6 +39,7 @@ if is_flax_available(): from transformers import ( FLAX_MODEL_MAPPING, + FlaxWhisperForAudioClassification, FlaxWhisperForConditionalGeneration, FlaxWhisperModel, WhisperFeatureExtractor, @@ -704,3 +703,205 @@ class FlaxWhisperModelIntegrationTest(unittest.TestCase): transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True) self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + + +class FlaxWhisperEncoderModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=60, + is_training=True, + use_labels=True, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + input_channels=1, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=20, + max_source_positions=30, + num_mel_bins=80, + num_conv_layers=1, + suppress_tokens=None, + begin_suppress_tokens=None, + classifier_proj_size=4, + num_labels=2, + is_encoder_decoder=False, + is_decoder=False, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.input_channels = input_channels + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.num_mel_bins = num_mel_bins + self.max_position_embeddings = max_position_embeddings + self.max_source_positions = max_source_positions + self.num_conv_layers = num_conv_layers + self.suppress_tokens = suppress_tokens + self.begin_suppress_tokens = begin_suppress_tokens + self.classifier_proj_size = classifier_proj_size + self.num_labels = num_labels + self.is_encoder_decoder = is_encoder_decoder + self.is_decoder = is_decoder + + def get_config(self): + return WhisperConfig( + d_model=self.hidden_size, + encoder_layers=self.num_hidden_layers, + decoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_attention_heads=self.num_attention_heads, + input_channels=self.input_channels, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + max_source_positions=self.max_source_positions, + decoder_ffn_dim=self.hidden_size, + encoder_ffn_dim=self.hidden_size, + suppress_tokens=self.suppress_tokens, + begin_suppress_tokens=self.begin_suppress_tokens, + classifier_proj_size=self.classifier_proj_size, + num_labels=self.num_labels, + is_encoder_decoder=self.is_encoder_decoder, + is_decoder=self.is_decoder, + ) + + def prepare_whisper_encoder_inputs_dict( + self, + input_features, + ): + return { + "input_features": input_features, + } + + def prepare_config_and_inputs(self): + input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length]) + + config = self.get_config() + inputs_dict = self.prepare_whisper_encoder_inputs_dict( + input_features=input_features, + ) + return config, inputs_dict + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + def get_subsampled_output_lengths(self, input_lengths): + """ + Computes the output length of the convolutional layers + """ + + for i in range(self.num_conv_layers): + input_lengths = (input_lengths - 1) // 2 + 1 + + return input_lengths + + @property + def encoder_seq_length(self): + return self.get_subsampled_output_lengths(self.seq_length) + + +@require_flax +class WhisperEncoderModelTest(FlaxModelTesterMixin, unittest.TestCase): + all_model_classes = (FlaxWhisperForAudioClassification,) if is_flax_available() else () + is_encoder_decoder = False + fx_compatible = False + test_pruning = False + test_missing_keys = False + + input_name = "input_features" + + def setUp(self): + self.model_tester = FlaxWhisperEncoderModelTester(self) + _, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + self.init_shape = (1,) + inputs_dict["input_features"].shape[1:] + + self.all_model_classes = ( + make_partial_class(model_class, input_shape=self.init_shape) for model_class in self.all_model_classes + ) + self.config_tester = ConfigTester(self, config_class=WhisperConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + # overwrite because of `input_features` + def test_jit_compilation(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + + @jax.jit + def model_jitted(input_features, **kwargs): + return model(input_features=input_features, **kwargs) + + with self.subTest("JIT Enabled"): + jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = model_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + self.assertEqual(jitted_output.shape, output.shape) + + # overwrite because of `input_features` + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.__call__) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["input_features", "attention_mask", "output_attentions"] + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + + def test_inputs_embeds(self): + pass + + # WhisperEncoder has no inputs_embeds and thus the `get_input_embeddings` fn is not implemented + def test_model_common_attributes(self): + pass + + # WhisperEncoder cannot resize token embeddings since it has no tokens embeddings + def test_resize_tokens_embeddings(self): + pass + + # WhisperEncoder does not have any base model + def test_save_load_to_base(self): + pass + + # WhisperEncoder does not have any base model + def test_save_load_from_base(self): + pass + + # WhisperEncoder does not have any base model + @is_pt_flax_cross_test + def test_save_load_from_base_pt(self): + pass + + # WhisperEncoder does not have any base model + @is_pt_flax_cross_test + def test_save_load_to_base_pt(self): + pass + + # WhisperEncoder does not have any base model + @is_pt_flax_cross_test + def test_save_load_bf16_to_base_pt(self): + pass diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 0591c6f4643..0b5b375e9dd 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -95,7 +95,7 @@ class WhisperModelTester: self, parent, batch_size=13, - seq_length=60, + seq_length=1500, is_training=True, use_labels=False, vocab_size=200, @@ -107,7 +107,7 @@ class WhisperModelTester: hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=20, - max_source_positions=30, + max_source_positions=750, max_target_positions=40, bos_token_id=98, eos_token_id=98, @@ -1434,7 +1434,7 @@ class WhisperEncoderModelTester: self, parent, batch_size=13, - seq_length=60, + seq_length=3000, is_training=True, use_labels=True, hidden_size=16, @@ -1445,7 +1445,7 @@ class WhisperEncoderModelTester: hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=20, - max_source_positions=30, + max_source_positions=1500, num_mel_bins=80, num_conv_layers=1, suppress_tokens=None,