mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
[Whisper
] add get_input_embeddings
to WhisperForAudioClassification
(#22133)
* add `get_input_embeddings` to `WhisperForAudioClassification`
* add common tests
* fix another common test
* Update tests/models/whisper/test_modeling_whisper.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix style

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
987972377d
commit
d979cf6efd
@@ -767,6 +767,12 @@ class WhisperEncoder(WhisperPreTrainedModel):
|
|||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
self._requires_grad = False
|
self._requires_grad = False
|
||||||
|
|
||||||
|
def get_input_embeddings(self) -> nn.Module:
|
||||||
|
return self.conv1
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value: nn.Module):
|
||||||
|
self.conv1 = value
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_features,
|
input_features,
|
||||||
@@ -1023,7 +1029,10 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
|
if input_ids is not None:
|
||||||
|
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
|
||||||
|
else:
|
||||||
|
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
|
||||||
|
|
||||||
hidden_states = inputs_embeds + positions
|
hidden_states = inputs_embeds + positions
|
||||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
@@ -1330,6 +1339,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
def set_output_embeddings(self, new_embeddings):
|
def set_output_embeddings(self, new_embeddings):
|
||||||
self.proj_out = new_embeddings
|
self.proj_out = new_embeddings
|
||||||
|
|
||||||
|
def get_input_embeddings(self) -> nn.Module:
|
||||||
|
return self.model.get_input_embeddings()
|
||||||
|
|
||||||
def freeze_encoder(self):
|
def freeze_encoder(self):
|
||||||
"""
|
"""
|
||||||
Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
|
Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
|
||||||
@@ -1635,6 +1647,12 @@ class WhisperForAudioClassification(WhisperPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
self.encoder._freeze_parameters()
|
self.encoder._freeze_parameters()
|
||||||
|
|
||||||
|
def get_input_embeddings(self) -> nn.Module:
|
||||||
|
return self.encoder.get_input_embeddings()
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value: nn.Module):
|
||||||
|
self.encoder.set_input_embeddings(value)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
|
@@ -357,9 +357,24 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
|
|
||||||
return config, input_ids, None, max_length
|
return config, input_ids, None, max_length
|
||||||
|
|
||||||
# not implemented currently
|
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
decoder_input_ids = inputs.pop("decoder_input_ids", None)
|
||||||
|
inputs.pop("decoder_attention_mask", None)
|
||||||
|
|
||||||
|
wte = model.get_input_embeddings()
|
||||||
|
inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
model(**inputs)[0]
|
||||||
|
|
||||||
# training is not supported yet
|
# training is not supported yet
|
||||||
def test_training(self):
|
def test_training(self):
|
||||||
@@ -1566,9 +1581,16 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
|
|||||||
|
|
||||||
self.assertTrue((outputs_embeds == outputs).all())
|
self.assertTrue((outputs_embeds == outputs).all())
|
||||||
|
|
||||||
# WhisperEncoder has no inputs_embeds and thus the `get_input_embeddings` fn is not implemented
|
# Needs to override as the encoder input embedding is a Conv1d
|
||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
pass
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Conv1d))
|
||||||
|
model.set_input_embeddings(torch.nn.Conv1d(10, 10, 3))
|
||||||
|
x = model.get_output_embeddings()
|
||||||
|
self.assertTrue(x is None or isinstance(x, torch.nn.Conv1d))
|
||||||
|
|
||||||
# WhisperEncoder cannot resize token embeddings since it has no tokens embeddings
|
# WhisperEncoder cannot resize token embeddings since it has no tokens embeddings
|
||||||
def test_resize_tokens_embeddings(self):
|
def test_resize_tokens_embeddings(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user