[Whiper] add get_input_embeddings to WhisperForAudioClassification (#22133)

* add `get_input_embeddings` to `WhisperForAudioClassification`

* add common tests

* fix another common test

* Update tests/models/whisper/test_modeling_whisper.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix style

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Younes Belkada 2023-03-13 19:46:01 +01:00 committed by GitHub
parent 987972377d
commit d979cf6efd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 5 deletions

View File

@ -767,6 +767,12 @@ class WhisperEncoder(WhisperPreTrainedModel):
param.requires_grad = False
self._requires_grad = False
def get_input_embeddings(self) -> nn.Module:
return self.conv1
def set_input_embeddings(self, value: nn.Module):
self.conv1 = value
def forward(
self,
input_features,
@ -1023,7 +1029,10 @@ class WhisperDecoder(WhisperPreTrainedModel):
)
# embed positions
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
if input_ids is not None:
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
else:
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
hidden_states = inputs_embeds + positions
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@ -1330,6 +1339,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.proj_out = new_embeddings
def get_input_embeddings(self) -> nn.Module:
return self.model.get_input_embeddings()
def freeze_encoder(self):
"""
Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
@ -1635,6 +1647,12 @@ class WhisperForAudioClassification(WhisperPreTrainedModel):
"""
self.encoder._freeze_parameters()
def get_input_embeddings(self) -> nn.Module:
return self.encoder.get_input_embeddings()
def set_input_embeddings(self, value: nn.Module):
self.encoder.set_input_embeddings(value)
@add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(

View File

@ -357,9 +357,24 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
return config, input_ids, None, max_length
# not implemented currently
def test_inputs_embeds(self):
pass
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
decoder_input_ids = inputs.pop("decoder_input_ids", None)
inputs.pop("decoder_attention_mask", None)
wte = model.get_input_embeddings()
inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
with torch.no_grad():
model(**inputs)[0]
# training is not supported yet
def test_training(self):
@ -1566,9 +1581,16 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
self.assertTrue((outputs_embeds == outputs).all())
# WhisperEncoder has no inputs_embeds and thus the `get_input_embeddings` fn is not implemented
# Needs to override as the encoder input embedding is a Conv1d
def test_model_common_attributes(self):
pass
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Conv1d))
model.set_input_embeddings(torch.nn.Conv1d(10, 10, 3))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, torch.nn.Conv1d))
# WhisperEncoder cannot resize token embeddings since it has no tokens embeddings
def test_resize_tokens_embeddings(self):