diff --git a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py index 99999d15474..e77f2e00472 100644 --- a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py @@ -404,9 +404,11 @@ class FlaxWav2Vec2FeatureEncoder(nn.Module): def setup(self): self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype) - def __call__(self, input_values): + def __call__(self, input_values, freeze_feature_encoder=False): hidden_states = input_values[:, :, None] hidden_states = self.conv_layers(hidden_states) + if freeze_feature_encoder: + hidden_states = jax.lax.stop_gradient(hidden_states) return hidden_states @@ -875,6 +877,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel): train: bool = False, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + freeze_feature_encoder: bool = False, return_dict: Optional[bool] = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -903,6 +906,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel): not train, output_attentions, output_hidden_states, + freeze_feature_encoder, return_dict, rngs=rngs, ) @@ -939,9 +943,10 @@ class FlaxWav2Vec2Module(nn.Module): deterministic=True, output_attentions=None, output_hidden_states=None, + freeze_feature_encoder=False, return_dict=None, ): - extract_features = self.feature_extractor(input_values) + extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder) # make sure that no loss is computed on padded inputs if attention_mask is not None: @@ -1101,6 +1106,7 @@ class FlaxWav2Vec2ForCTCModule(nn.Module): deterministic=True, output_attentions=None, output_hidden_states=None, + freeze_feature_encoder=False, return_dict=None, ): outputs = self.wav2vec2( @@ -1110,6 +1116,7 @@ class 
FlaxWav2Vec2ForCTCModule(nn.Module): deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + freeze_feature_encoder=freeze_feature_encoder, return_dict=return_dict, ) @@ -1232,6 +1239,7 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module): deterministic: bool = True, output_attentions=None, output_hidden_states=None, + freeze_feature_encoder=False, return_dict=None, ): r""" @@ -1252,6 +1260,7 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module): output_hidden_states=output_hidden_states, mask_time_indices=mask_time_indices, deterministic=deterministic, + freeze_feature_encoder=freeze_feature_encoder, return_dict=return_dict, ) @@ -1310,6 +1319,7 @@ class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel): train: bool = False, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + freeze_feature_encoder: bool = False, return_dict: Optional[bool] = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1342,6 +1352,7 @@ class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel): not train, output_attentions, output_hidden_states, + freeze_feature_encoder, return_dict, rngs=rngs, ) diff --git a/tests/wav2vec2/test_modeling_flax_wav2vec2.py b/tests/wav2vec2/test_modeling_flax_wav2vec2.py index 6e25ac8281a..42f904b4cc0 100644 --- a/tests/wav2vec2/test_modeling_flax_wav2vec2.py +++ b/tests/wav2vec2/test_modeling_flax_wav2vec2.py @@ -229,6 +229,47 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase): self.assertEqual(jitted_output.shape, output.shape) + def test_freeze_feature_encoder(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + input_values = inputs_dict["input_values"] + attention_mask = inputs_dict["attention_mask"] + + model = FlaxWav2Vec2ForPreTraining(config) + + outputs = model( + input_values, + attention_mask=attention_mask, + 
freeze_feature_encoder=False, + ) + + outputs_frozen = model( + input_values, + attention_mask=attention_mask, + freeze_feature_encoder=True, + ) + + # dummy loss function + def compute_loss(projected_states, projected_quantized_states, epsilon=1e-8): + # compute cosine similarity of projected and projected_quantized states + cosine_sim = optax.cosine_similarity(projected_states, projected_quantized_states, epsilon=epsilon) + loss = cosine_sim.sum() + return loss + + # transform the loss function to get the gradients + grad_fn = jax.value_and_grad(compute_loss) + + # compute loss and gradients for unfrozen model + loss, grads = grad_fn(outputs.projected_states, outputs.projected_quantized_states) + + # compare to loss and gradients for frozen model + loss_frozen, grads_frozen = grad_fn(outputs_frozen.projected_states, outputs_frozen.projected_quantized_states) + + self.assertLessEqual(np.abs(loss - loss_frozen), 1e-5) + self.assertEqual(grads.shape, grads_frozen.shape) + max_diff = np.amax(np.abs(grads - grads_frozen)) + self.assertLessEqual(max_diff, 1e-5) + @slow def test_model_from_pretrained(self): for model_class_name in self.all_model_classes: