[Whisper] add get_input_embeddings to WhisperForAudioClassification (#22133)

* add `get_input_embeddings` to `WhisperForAudioClassification`

* add common tests

* fix another common test

* Update tests/models/whisper/test_modeling_whisper.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix style

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Younes Belkada 2023-03-13 19:46:01 +01:00 committed by GitHub
parent 987972377d
commit d979cf6efd
2 changed files with 45 additions and 5 deletions

src/transformers/models/whisper/modeling_whisper.py

@@ -767,6 +767,12 @@ class WhisperEncoder(WhisperPreTrainedModel):
             param.requires_grad = False
         self._requires_grad = False
 
+    def get_input_embeddings(self) -> nn.Module:
+        return self.conv1
+
+    def set_input_embeddings(self, value: nn.Module):
+        self.conv1 = value
+
     def forward(
         self,
         input_features,
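For context (not part of the diff): the encoder has no token-embedding matrix, so its first convolution acts as the "input embedding", projecting log-mel features up to the model width. A minimal standalone sketch, assuming whisper-tiny-sized dimensions (80 mel bins, d_model = 384):

    import torch

    # Assumed whisper-tiny-like dimensions, for illustration only.
    conv1 = torch.nn.Conv1d(in_channels=80, out_channels=384, kernel_size=3, padding=1)

    log_mel = torch.randn(2, 80, 3000)  # (batch, num_mel_bins, frames)
    print(conv1(log_mel).shape)         # torch.Size([2, 384, 3000])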
@@ -1023,7 +1029,10 @@ class WhisperDecoder(WhisperPreTrainedModel):
         )
 
         # embed positions
-        positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
+        if input_ids is not None:
+            positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
+        else:
+            positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
 
         hidden_states = inputs_embeds + positions
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
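This branch is what makes the new common test pass: positions can now be derived from `decoder_inputs_embeds` when no decoder ids are given. A sketch of the call path it unlocks (not part of the diff; the checkpoint name is illustrative):

    import torch
    from transformers import WhisperForConditionalGeneration

    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    model.eval()

    input_features = torch.randn(1, 80, 3000)  # (batch, num_mel_bins, frames)
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

    # embed the ids manually and feed embeddings instead of ids
    wte = model.get_input_embeddings()  # decoder token embedding, via the new accessor
    with torch.no_grad():
        out = model(
            input_features=input_features,
            decoder_inputs_embeds=wte(decoder_input_ids),
        )
    print(out.logits.shape)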
@@ -1330,6 +1339,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.proj_out = new_embeddings
 
+    def get_input_embeddings(self) -> nn.Module:
+        return self.model.get_input_embeddings()
+
     def freeze_encoder(self):
         """
         Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
@@ -1635,6 +1647,12 @@ class WhisperForAudioClassification(WhisperPreTrainedModel):
         """
         self.encoder._freeze_parameters()
 
+    def get_input_embeddings(self) -> nn.Module:
+        return self.encoder.get_input_embeddings()
+
+    def set_input_embeddings(self, value: nn.Module):
+        self.encoder.set_input_embeddings(value)
+
     @add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
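A hedged usage sketch of the new accessors (the checkpoint name is an assumption, and its classification head will be freshly initialized): the returned module is the encoder's `conv1`, so it can be inspected or swapped like any other input embedding.

    import torch
    from transformers import WhisperForAudioClassification

    model = WhisperForAudioClassification.from_pretrained("openai/whisper-tiny")

    emb = model.get_input_embeddings()  # forwards to encoder.conv1
    assert isinstance(emb, torch.nn.Conv1d)

    # replace it, e.g. to accept a different number of mel bins
    new_conv = torch.nn.Conv1d(64, emb.out_channels, kernel_size=3, padding=1)
    model.set_input_embeddings(new_conv)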

tests/models/whisper/test_modeling_whisper.py

@@ -357,9 +357,24 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
         return config, input_ids, None, max_length
 
-    # not implemented currently
     def test_inputs_embeds(self):
-        pass
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+
+            decoder_input_ids = inputs.pop("decoder_input_ids", None)
+            inputs.pop("decoder_attention_mask", None)
+
+            wte = model.get_input_embeddings()
+            inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
+
+            with torch.no_grad():
+                model(**inputs)[0]
 
     # training is not supported yet
     def test_training(self):
@@ -1566,9 +1581,16 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
             self.assertTrue((outputs_embeds == outputs).all())
 
-    # WhisperEncoder has no inputs_embeds and thus the `get_input_embeddings` fn is not implemented
+    # Needs to override as the encoder input embedding is a Conv1d
     def test_model_common_attributes(self):
-        pass
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Conv1d))
+            model.set_input_embeddings(torch.nn.Conv1d(10, 10, 3))
+            x = model.get_output_embeddings()
+            self.assertTrue(x is None or isinstance(x, torch.nn.Conv1d))
 
     # WhisperEncoder cannot resize token embeddings since it has no tokens embeddings
     def test_resize_tokens_embeddings(self):
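To exercise just the touched tests from the repository root, something like `python -m pytest tests/models/whisper/test_modeling_whisper.py -k "inputs_embeds or model_common_attributes"` should select both (the `-k` filter is a convenience, not part of the PR).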