[Whisper] Freeze params of encoder (#19527)

* [Whisper] Freeze params of encoder

* add tests
Sanchit Gandhi 2022-10-13 09:50:02 +01:00 committed by GitHub
parent 504cd71a6b
commit bbd150e92f
2 changed files with 43 additions and 1 deletion

src/transformers/models/whisper/modeling_whisper.py

@@ -609,6 +609,11 @@ class WhisperEncoder(WhisperPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def _freeze_parameters(self):
+        for param in self.parameters():
+            param.requires_grad = False
+        self._requires_grad = False
+
     def forward(
         self,
         input_features,
@@ -991,6 +996,13 @@ class WhisperModel(WhisperPreTrainedModel):
     def get_decoder(self):
         return self.decoder
 
+    def freeze_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
+        not be updated during training.
+        """
+        self.encoder._freeze_parameters()
+
     @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_PROCESSOR_FOR_DOC,
@@ -1109,6 +1121,13 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.proj_out = new_embeddings
 
+    def freeze_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
+        not be updated during training.
+        """
+        self.model.encoder._freeze_parameters()
+
     @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
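A minimal usage sketch of the new API for fine-tuning with a frozen encoder; the checkpoint name, optimizer choice, and learning rate below are illustrative, not part of this change:

# Freeze the Whisper encoder before fine-tuning so that only the decoder
# (and tied output projection) receive gradient updates.
import torch
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.freeze_encoder()  # sets requires_grad=False on every encoder parameter

# Pass only the parameters that still require gradients to the optimizer.
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-5
)

assert not any(p.requires_grad for p in model.model.encoder.parameters())
assert all(p.requires_grad for p in model.model.decoder.parameters())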

tests/models/whisper/test_modeling_whisper.py

@@ -182,9 +182,12 @@ class WhisperModelTester:
         return input_lengths
 
-    def create_and_check_model_forward(self, config, inputs_dict):
+    def create_and_check_model_forward(self, config, inputs_dict, freeze_encoder=False):
         model = WhisperModel(config=config).to(torch_device).eval()
 
+        if freeze_encoder:
+            model.freeze_encoder()
+
         input_features = inputs_dict["input_features"]
         decoder_input_ids = inputs_dict["decoder_input_ids"]
@@ -289,6 +292,26 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model_forward(*config_and_inputs)
 
+    def test_model_forward_with_frozen_encoder(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model_forward(*config_and_inputs, freeze_encoder=True)
+
+    def test_requires_grad_with_frozen_encoder(self):
+        config = self.model_tester.get_config()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.freeze_encoder()
+
+            try:
+                encoder_grads = [param.requires_grad for param in model.encoder.parameters()]
+                decoder_grads = [param.requires_grad for param in model.decoder.parameters()]
+            except AttributeError:
+                encoder_grads = [param.requires_grad for param in model.model.encoder.parameters()]
+                decoder_grads = [param.requires_grad for param in model.model.decoder.parameters()]
+
+            self.assertFalse(all(encoder_grads))
+            self.assertTrue(all(decoder_grads))
+
     def test_decoder_model_past_with_large_inputs(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
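As a rough complement to the requires_grad test above, one could also check that no gradients reach the frozen encoder after a backward pass. The default config, random inputs, and shapes below are only a sketch and not part of this PR:

# After model.freeze_encoder(), a backward pass should leave encoder .grad
# fields empty while decoder parameters accumulate gradients.
import torch
from transformers import WhisperConfig, WhisperForConditionalGeneration

config = WhisperConfig()  # default config; sizes are illustrative
model = WhisperForConditionalGeneration(config)
model.freeze_encoder()

# Whisper expects 2 * max_source_positions mel frames with num_mel_bins channels.
input_features = torch.randn(1, config.num_mel_bins, 2 * config.max_source_positions)
labels = torch.randint(0, config.vocab_size, (1, 8))

loss = model(input_features=input_features, labels=labels).loss
loss.backward()

assert all(p.grad is None for p in model.model.encoder.parameters())
assert any(p.grad is not None for p in model.model.decoder.parameters())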