Freeze FlaxWav2Vec2 Feature Encoder (#15873)

* Freeze FlaxWav2Vec2 Feature Encoder
* add to all module apply
* add backprop test

parent 7b3bd1f21a
commit 3c4fbc616f
modeling_flax_wav2vec2.py

@@ -404,9 +404,11 @@ class FlaxWav2Vec2FeatureEncoder(nn.Module):
     def setup(self):
         self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)
 
-    def __call__(self, input_values):
+    def __call__(self, input_values, freeze_feature_encoder=False):
         hidden_states = input_values[:, :, None]
         hidden_states = self.conv_layers(hidden_states)
+        if freeze_feature_encoder:
+            hidden_states = jax.lax.stop_gradient(hidden_states)
         return hidden_states
 
 
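The hunk above is the core of the change: `jax.lax.stop_gradient` is the identity on the forward pass but blocks backpropagation, so nothing upstream of the encoder output (in particular the convolutional layers' parameters) receives any gradient when the flag is set. A minimal, self-contained sketch of the pattern, where `encode` is a toy stand-in for the conv stack, not the real encoder:

import jax
import jax.numpy as jnp

def encode(x, freeze_feature_encoder=False):
    # toy stand-in for the convolutional feature encoder
    hidden_states = jnp.tanh(2.0 * x)
    if freeze_feature_encoder:
        # identity on the forward pass, zero gradient on the backward pass
        hidden_states = jax.lax.stop_gradient(hidden_states)
    return hidden_states

def loss_fn(x, freeze_feature_encoder):
    return jnp.sum(encode(x, freeze_feature_encoder) ** 2)

x = jnp.ones(3)
print(jax.grad(loss_fn)(x, False))  # non-zero gradients reach the input
print(jax.grad(loss_fn)(x, True))   # all zeros: backprop is cut at the encoder output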
@@ -875,6 +877,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
         train: bool = False,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        freeze_feature_encoder: bool = False,
         return_dict: Optional[bool] = None,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -903,6 +906,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
             not train,
             output_attentions,
             output_hidden_states,
+            freeze_feature_encoder,
             return_dict,
             rngs=rngs,
         )
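These two hunks thread the flag from the user-facing `__call__` into `self.module.apply` as a plain positional argument, which is what the commit message's "add to all module apply" bullet refers to. A compact sketch of that threading, assuming nothing beyond standard Flax (`TinyModule` and its shapes are made up for illustration):

import flax.linen as nn
import jax
import jax.numpy as jnp

class TinyModule(nn.Module):
    @nn.compact
    def __call__(self, x, freeze_feature_encoder=False):
        hidden = nn.Dense(4)(x)
        if freeze_feature_encoder:
            hidden = jax.lax.stop_gradient(hidden)
        return hidden

module = TinyModule()
params = module.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
# the flag rides through apply as an ordinary positional argument,
# mirroring how the hunks above pass it to self.module.apply
out = module.apply(params, jnp.ones((1, 8)), True)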
@@ -939,9 +943,10 @@ class FlaxWav2Vec2Module(nn.Module):
         deterministic=True,
         output_attentions=None,
         output_hidden_states=None,
+        freeze_feature_encoder=False,
         return_dict=None,
     ):
-        extract_features = self.feature_extractor(input_values)
+        extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)
 
         # make sure that no loss is computed on padded inputs
         if attention_mask is not None:
@@ -1101,6 +1106,7 @@ class FlaxWav2Vec2ForCTCModule(nn.Module):
         deterministic=True,
         output_attentions=None,
         output_hidden_states=None,
+        freeze_feature_encoder=False,
         return_dict=None,
     ):
         outputs = self.wav2vec2(
@@ -1110,6 +1116,7 @@ class FlaxWav2Vec2ForCTCModule(nn.Module):
             deterministic=deterministic,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            freeze_feature_encoder=freeze_feature_encoder,
             return_dict=return_dict,
         )
 
@@ -1232,6 +1239,7 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
         deterministic: bool = True,
         output_attentions=None,
         output_hidden_states=None,
+        freeze_feature_encoder=False,
         return_dict=None,
     ):
         r"""
@@ -1252,6 +1260,7 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
             output_hidden_states=output_hidden_states,
             mask_time_indices=mask_time_indices,
             deterministic=deterministic,
+            freeze_feature_encoder=freeze_feature_encoder,
             return_dict=return_dict,
         )
 
@@ -1310,6 +1319,7 @@ class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel):
         train: bool = False,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        freeze_feature_encoder: bool = False,
         return_dict: Optional[bool] = None,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1342,6 +1352,7 @@ class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel):
             not train,
             output_attentions,
             output_hidden_states,
+            freeze_feature_encoder,
             return_dict,
             rngs=rngs,
         )
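With the flag threaded through every module, callers can freeze the encoder per forward pass. A hedged usage sketch; the checkpoint name and the `from_pt=True` conversion are assumptions for illustration, any wav2vec2 checkpoint loadable in Flax works:

import jax.numpy as jnp
from transformers import FlaxWav2Vec2ForPreTraining

# checkpoint name is an assumption; from_pt=True converts PyTorch weights
model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base", from_pt=True)

input_values = jnp.zeros((1, 16000), dtype=jnp.float32)  # one second of 16 kHz audio

# forward pass with the convolutional feature encoder frozen
outputs = model(input_values, freeze_feature_encoder=True)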
test_modeling_flax_wav2vec2.py

@@ -229,6 +229,47 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
 
         self.assertEqual(jitted_output.shape, output.shape)
 
+    def test_freeze_feature_encoder(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        input_values = inputs_dict["input_values"]
+        attention_mask = inputs_dict["attention_mask"]
+
+        model = FlaxWav2Vec2ForPreTraining(config)
+
+        outputs = model(
+            input_values,
+            attention_mask=attention_mask,
+            freeze_feature_encoder=False,
+        )
+
+        outputs_frozen = model(
+            input_values,
+            attention_mask=attention_mask,
+            freeze_feature_encoder=True,
+        )
+
+        # dummy loss function
+        def compute_loss(projected_states, projected_quantized_states, epsilon=1e-8):
+            # compute cosine similarity of projected and projected_quantized states
+            cosine_sim = optax.cosine_similarity(projected_states, projected_quantized_states, epsilon=epsilon)
+            loss = cosine_sim.sum()
+            return loss
+
+        # transform the loss function to get the gradients
+        grad_fn = jax.value_and_grad(compute_loss)
+
+        # compute loss and gradients for unfrozen model
+        loss, grads = grad_fn(outputs.projected_states, outputs.projected_quantized_states)
+
+        # compare to loss and gradients for frozen model
+        loss_frozen, grads_frozen = grad_fn(outputs_frozen.projected_states, outputs_frozen.projected_quantized_states)
+
+        self.assertLessEqual(np.abs(loss - loss_frozen), 1e-5)
+        self.assertEqual(grads.shape, grads_frozen.shape)
+        max_diff = np.amax(np.abs(grads - grads_frozen))
+        self.assertLessEqual(max_diff, 1e-5)
+
     @slow
     def test_model_from_pretrained(self):
         for model_class_name in self.all_model_classes:
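Freezing changes which parameters would receive gradients inside the model, not the model's outputs, so the test expects the two runs to produce identical losses and identical gradients with respect to the outputs. Note also that `jax.value_and_grad` differentiates only with respect to its first argument by default, which is why `grads` has the shape of `projected_states` and the `grads.shape` comparison is meaningful. A standalone sketch of the dummy loss, with arbitrary shapes:

import jax
import jax.numpy as jnp
import optax

def compute_loss(projected_states, projected_quantized_states, epsilon=1e-8):
    # cosine similarity along the last axis, summed to a scalar loss
    cosine_sim = optax.cosine_similarity(projected_states, projected_quantized_states, epsilon=epsilon)
    return cosine_sim.sum()

# differentiates w.r.t. argument 0 (projected_states) by default
grad_fn = jax.value_and_grad(compute_loss)

projected = jnp.ones((2, 5, 8))
quantized = jnp.full((2, 5, 8), 0.5)
loss, grads = grad_fn(projected, quantized)
print(loss, grads.shape)  # grads.shape == projected.shape == (2, 5, 8)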