Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 10:12:23 +06:00)
[Whisper] Freeze params of encoder (#19527)
* [Whisper] Freeze params of encoder
* add tests
parent 504cd71a6b
commit bbd150e92f
src/transformers/models/whisper/modeling_whisper.py

@@ -609,6 +609,11 @@ class WhisperEncoder(WhisperPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def _freeze_parameters(self):
+        for param in self.parameters():
+            param.requires_grad = False
+        self._requires_grad = False
+
     def forward(
         self,
         input_features,
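For context, a minimal standalone sketch of the same requires_grad pattern on a toy module (the module and names here are illustrative stand-ins, not part of the diff):

import torch.nn as nn

# Toy stand-in for WhisperEncoder (illustrative only)
encoder = nn.Sequential(nn.Linear(80, 256), nn.ReLU(), nn.Linear(256, 256))

# Same pattern as _freeze_parameters: turn off gradients for every parameter
for param in encoder.parameters():
    param.requires_grad = False

# All parameters are now excluded from gradient computation
assert not any(p.requires_grad for p in encoder.parameters())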
@@ -991,6 +996,13 @@ class WhisperModel(WhisperPreTrainedModel):
     def get_decoder(self):
         return self.decoder
 
+    def freeze_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
+        not be updated during training.
+        """
+        self.encoder._freeze_parameters()
+
     @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_PROCESSOR_FOR_DOC,
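A quick usage sketch of the new WhisperModel.freeze_encoder method; the checkpoint name and the parameter-counting helper are assumptions for illustration:

from transformers import WhisperModel

model = WhisperModel.from_pretrained("openai/whisper-tiny")  # checkpoint name assumed

def n_trainable(m):
    # count parameters that will still receive gradient updates
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

before = n_trainable(model)
model.freeze_encoder()
after = n_trainable(model)
print(f"trainable parameters: {before:,} -> {after:,}")  # encoder drops out of the count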
@@ -1109,6 +1121,13 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.proj_out = new_embeddings
 
+    def freeze_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
+        not be updated during training.
+        """
+        self.model.encoder._freeze_parameters()
+
     @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
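For fine-tuning, the typical pattern would be to freeze the encoder and hand the optimizer only the parameters that still require grad. A sketch, assuming the openai/whisper-tiny checkpoint, AdamW, and a learning rate chosen purely for illustration:

import torch
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.freeze_encoder()  # encoder weights stay fixed; decoder-side weights still train

# Pass only the still-trainable parameters to the optimizer
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-5
)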
tests/models/whisper/test_modeling_whisper.py

@@ -182,9 +182,12 @@ class WhisperModelTester:
 
         return input_lengths
 
-    def create_and_check_model_forward(self, config, inputs_dict):
+    def create_and_check_model_forward(self, config, inputs_dict, freeze_encoder=False):
         model = WhisperModel(config=config).to(torch_device).eval()
 
+        if freeze_encoder:
+            model.freeze_encoder()
+
         input_features = inputs_dict["input_features"]
         decoder_input_ids = inputs_dict["decoder_input_ids"]
 
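The property this tester change exercises is that freezing only affects training; the forward pass itself is unchanged. A sketch of the same check outside the test harness, assuming the default WhisperConfig (the test suite uses a much smaller configuration):

import torch
from transformers import WhisperConfig, WhisperModel

config = WhisperConfig()  # default config, randomly initialized weights
model = WhisperModel(config).eval()
model.freeze_encoder()

# Forward still runs with the encoder frozen
input_features = torch.randn(1, config.num_mel_bins, 3000)  # (batch, mels, frames)
decoder_input_ids = torch.tensor([[config.decoder_start_token_id]])
with torch.no_grad():
    outputs = model(input_features=input_features, decoder_input_ids=decoder_input_ids)
print(outputs.last_hidden_state.shape)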
@@ -289,6 +292,26 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model_forward(*config_and_inputs)
 
+    def test_model_forward_with_frozen_encoder(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model_forward(*config_and_inputs, freeze_encoder=True)
+
+    def test_requires_grad_with_frozen_encoder(self):
+        config = self.model_tester.get_config()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.freeze_encoder()
+
+            try:
+                encoder_grads = [param.requires_grad for param in model.encoder.parameters()]
+                decoder_grads = [param.requires_grad for param in model.decoder.parameters()]
+            except AttributeError:
+                encoder_grads = [param.requires_grad for param in model.model.encoder.parameters()]
+                decoder_grads = [param.requires_grad for param in model.model.decoder.parameters()]
+
+            self.assertFalse(all(encoder_grads))
+            self.assertTrue(all(decoder_grads))
+
     def test_decoder_model_past_with_large_inputs(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
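The try/except in the test handles the two module layouts seen in the diff: WhisperModel exposes model.encoder directly, while WhisperForConditionalGeneration wraps it as model.model.encoder. A minimal sketch of the same check for the wrapped case, using a randomly initialized default config:

from transformers import WhisperConfig, WhisperForConditionalGeneration

model = WhisperForConditionalGeneration(WhisperConfig())
model.freeze_encoder()

# Every encoder parameter is frozen; every decoder parameter still trains
# (a slightly stronger check than the test's assertFalse(all(...)))
assert not any(p.requires_grad for p in model.model.encoder.parameters())
assert all(p.requires_grad for p in model.model.decoder.parameters())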