fix onnx export of speech foundation models (#34224)

* added expanded attention/padding masks prior to indexing the hidden_states

* consistency fix in WavLMForSequenceClassification

---------

Co-authored-by: Nikos Antoniou <nikosantoniou@Nikos-MacBook-Pro.local>
This commit is contained in:
Nikos Antoniou 2024-12-20 10:22:05 +02:00 committed by GitHub
parent f42084e641
commit ff9141bb85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 31 additions and 17 deletions

View File

@ -1421,7 +1421,8 @@ class Data2VecAudioForSequenceClassification(Data2VecAudioPreTrainedModel):
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
hidden_states[~padding_mask] = 0.0
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)

View File

@ -1629,7 +1629,8 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
hidden_states[~padding_mask] = 0.0
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)

View File

@ -882,15 +882,15 @@ class SEWEncoder(nn.Module):
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
if self._use_flash_attention_2:
# make sure padded tokens output 0
hidden_states[~attention_mask] = 0.0
hidden_states[~expand_attention_mask] = 0.0
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# make sure padded tokens output 0
hidden_states[~attention_mask] = 0.0
hidden_states[~expand_attention_mask] = 0.0
input_lengths = (attention_mask.long()).sum(-1)
# apply pooling formula to get real output_lengths
output_lengths = input_lengths // self.config.squeeze_factor
@ -1473,7 +1473,8 @@ class SEWForSequenceClassification(SEWPreTrainedModel):
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
hidden_states[~padding_mask] = 0.0
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)

View File

@ -1175,7 +1175,8 @@ class SEWDEncoder(nn.Module):
)
else:
# make sure padded tokens output 0
hidden_states[~attention_mask.bool()] = 0.0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask.bool()] = 0.0
input_lengths = (attention_mask.long()).sum(-1)
# apply pooling formula to get real output_lengths
@ -1721,7 +1722,8 @@ class SEWDForSequenceClassification(SEWDPreTrainedModel):
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
hidden_states[~padding_mask] = 0.0
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)

View File

@ -1876,7 +1876,8 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
hidden_states[~padding_mask] = 0.0
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)

View File

@ -1886,7 +1886,8 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
hidden_states[~padding_mask] = 0.0
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)

View File

@ -2376,7 +2376,8 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
hidden_states[~padding_mask] = 0.0
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)

View File

@ -1359,7 +1359,8 @@ class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
hidden_states[~padding_mask] = 0.0
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)

View File

@ -878,7 +878,8 @@ class Wav2Vec2ConformerEncoder(nn.Module):
if attention_mask is not None:
# make sure padded tokens output 0
hidden_states[~attention_mask] = 0.0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0.0
# extend attention_mask
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
@ -1791,7 +1792,8 @@ class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedMode
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
hidden_states[~padding_mask] = 0.0
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)

View File

@ -691,7 +691,8 @@ class WavLMEncoder(nn.Module):
if attention_mask is not None:
# make sure padded tokens output 0
hidden_states[~attention_mask] = 0.0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
@ -776,7 +777,8 @@ class WavLMEncoderStableLayerNorm(nn.Module):
if attention_mask is not None:
# make sure padded tokens are not attended to
hidden_states[~attention_mask] = 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
@ -1508,7 +1510,8 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
hidden_states[~padding_mask] = 0.0
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)