Freeze FlaxWav2Vec2 Feature Encoder (#15873)

* Freeze FlaxWav2Vec2 Feature Encoder

* add to all module apply

* add backprop test
This commit is contained in:
Sanchit Gandhi 2022-03-03 14:17:13 +01:00 committed by GitHub
parent 7b3bd1f21a
commit 3c4fbc616f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 2 deletions

View File

@ -404,9 +404,11 @@ class FlaxWav2Vec2FeatureEncoder(nn.Module):
def setup(self):
self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)
def __call__(self, input_values):
def __call__(self, input_values, freeze_feature_encoder=False):
hidden_states = input_values[:, :, None]
hidden_states = self.conv_layers(hidden_states)
if freeze_feature_encoder:
hidden_states = jax.lax.stop_gradient(hidden_states)
return hidden_states
@ -875,6 +877,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
freeze_feature_encoder: bool = False,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@ -903,6 +906,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
not train,
output_attentions,
output_hidden_states,
freeze_feature_encoder,
return_dict,
rngs=rngs,
)
@ -939,9 +943,10 @@ class FlaxWav2Vec2Module(nn.Module):
deterministic=True,
output_attentions=None,
output_hidden_states=None,
freeze_feature_encoder=False,
return_dict=None,
):
extract_features = self.feature_extractor(input_values)
extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)
# make sure that no loss is computed on padded inputs
if attention_mask is not None:
@ -1101,6 +1106,7 @@ class FlaxWav2Vec2ForCTCModule(nn.Module):
deterministic=True,
output_attentions=None,
output_hidden_states=None,
freeze_feature_encoder=False,
return_dict=None,
):
outputs = self.wav2vec2(
@ -1110,6 +1116,7 @@ class FlaxWav2Vec2ForCTCModule(nn.Module):
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
freeze_feature_encoder=freeze_feature_encoder,
return_dict=return_dict,
)
@ -1232,6 +1239,7 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
deterministic: bool = True,
output_attentions=None,
output_hidden_states=None,
freeze_feature_enocder=False,
return_dict=None,
):
r"""
@ -1252,6 +1260,7 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
output_hidden_states=output_hidden_states,
mask_time_indices=mask_time_indices,
deterministic=deterministic,
freeze_feature_encoder=freeze_feature_enocder,
return_dict=return_dict,
)
@ -1310,6 +1319,7 @@ class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel):
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
freeze_feature_encoder: bool = False,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@ -1342,6 +1352,7 @@ class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel):
not train,
output_attentions,
output_hidden_states,
freeze_feature_encoder,
return_dict,
rngs=rngs,
)

View File

@ -229,6 +229,47 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
self.assertEqual(jitted_output.shape, output.shape)
def test_freeze_feature_encoder(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_values = inputs_dict["input_values"]
attention_mask = inputs_dict["attention_mask"]
model = FlaxWav2Vec2ForPreTraining(config)
outputs = model(
input_values,
attention_mask=attention_mask,
freeze_feature_encoder=False,
)
outputs_frozen = model(
input_values,
attention_mask=attention_mask,
freeze_feature_encoder=True,
)
# dummy loss function
def compute_loss(projected_states, projected_quantized_states, epsilon=1e-8):
# compute cosine similarity of projected and projected_quantized states
cosine_sim = optax.cosine_similarity(projected_states, projected_quantized_states, epsilon=epsilon)
loss = cosine_sim.sum()
return loss
# transform the loss function to get the gradients
grad_fn = jax.value_and_grad(compute_loss)
# compute loss and gradients for unfrozen model
loss, grads = grad_fn(outputs.projected_states, outputs.projected_quantized_states)
# compare to loss and gradients for frozen model
loss_frozen, grads_frozen = grad_fn(outputs_frozen.projected_states, outputs_frozen.projected_quantized_states)
self.assertLessEqual(np.abs(loss - loss_frozen), 1e-5)
self.assertEqual(grads.shape, grads_frozen.shape)
max_diff = np.amax(np.abs(grads - grads_frozen))
self.assertLessEqual(max_diff, 1e-5)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes: