fix onnx export of speech foundation models (#34224)

* added expanded attention/padding masks prior to indexing the hidden_states
* consistency fix in WavLMForSequenceClassification

Co-authored-by: Nikos Antoniou <nikosantoniou@Nikos-MacBook-Pro.local>
commit ff9141bb85
parent f42084e641
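The change is the same everywhere in this diff: the models zeroed padded frames by indexing a 3D hidden_states tensor with a 2D boolean mask, leaving the broadcast over the hidden dimension implicit, and that implicit indexing is what broke ONNX export. Expanding the mask to the tensor's full shape before indexing makes every shape explicit. A minimal sketch of the before/after pattern (tensor sizes invented for illustration):

# Sketch of the pattern applied throughout this commit; sizes are toy values.
import torch

batch, seq_len, hidden = 2, 5, 4
hidden_states = torch.randn(batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]], dtype=torch.bool)

# Before: a 2D mask indexes a 3D tensor, broadcasting implicitly over the
# hidden dimension (fine in eager mode, but broke ONNX export here):
#   hidden_states[~attention_mask] = 0.0

# After: expand the mask to hidden_states' full shape first, so the boolean
# index has exactly the same rank and shape as the tensor it indexes:
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0.0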
src/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -1421,7 +1421,8 @@ class Data2VecAudioForSequenceClassification(Data2VecAudioPreTrainedModel):
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
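Every sequence-classification head touched below repeats this masked mean pooling: zero the padded frames, sum over the time axis, then divide by the per-example count of valid frames. A toy check (tensors invented for illustration) that the division recovers a true mean over only the unpadded frames:

# Masked mean pooling as used in the classification heads in this diff.
import torch

hidden_states = torch.ones(1, 4, 3)                   # (batch, frames, hidden)
padding_mask = torch.tensor([[True, True, True, False]])  # last frame is padding

expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_padding_mask] = 0.0

pooled = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
# Each valid frame is all-ones, so the masked mean is exactly 1.0:
assert torch.allclose(pooled, torch.ones(1, 3))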
src/transformers/models/hubert/modeling_hubert.py
@@ -1629,7 +1629,8 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
src/transformers/models/sew/modeling_sew.py
@@ -882,15 +882,15 @@ class SEWEncoder(nn.Module):
         all_self_attentions = () if output_attentions else None
 
         if attention_mask is not None:
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
             if self._use_flash_attention_2:
                 # make sure padded tokens output 0
-                hidden_states[~attention_mask] = 0.0
+                hidden_states[~expand_attention_mask] = 0.0
                 # 2d mask is passed through the layers
                 attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
             else:
                 # make sure padded tokens output 0
-                hidden_states[~attention_mask] = 0.0
-
+                hidden_states[~expand_attention_mask] = 0.0
                 input_lengths = (attention_mask.long()).sum(-1)
                 # apply pooling formula to get real output_lengths
                 output_lengths = input_lengths // self.config.squeeze_factor
@@ -1473,7 +1473,8 @@ class SEWForSequenceClassification(SEWPreTrainedModel):
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
src/transformers/models/sew_d/modeling_sew_d.py
@@ -1175,7 +1175,8 @@ class SEWDEncoder(nn.Module):
             )
         else:
             # make sure padded tokens output 0
-            hidden_states[~attention_mask.bool()] = 0.0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask.bool()] = 0.0
 
             input_lengths = (attention_mask.long()).sum(-1)
             # apply pooling formula to get real output_lengths
@@ -1721,7 +1722,8 @@ class SEWDForSequenceClassification(SEWDPreTrainedModel):
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
src/transformers/models/unispeech/modeling_unispeech.py
@@ -1876,7 +1876,8 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -1886,7 +1886,8 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -2376,7 +2376,8 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
@@ -1359,7 +1359,8 @@ class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
@@ -878,7 +878,8 @@ class Wav2Vec2ConformerEncoder(nn.Module):
 
         if attention_mask is not None:
             # make sure padded tokens output 0
-            hidden_states[~attention_mask] = 0.0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0.0
 
             # extend attention_mask
             attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
@@ -1791,7 +1792,8 @@ class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
src/transformers/models/wavlm/modeling_wavlm.py
@@ -691,7 +691,8 @@ class WavLMEncoder(nn.Module):
 
         if attention_mask is not None:
             # make sure padded tokens output 0
-            hidden_states[~attention_mask] = 0.0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
 
         position_embeddings = self.pos_conv_embed(hidden_states)
         hidden_states = hidden_states + position_embeddings
@@ -776,7 +777,8 @@ class WavLMEncoderStableLayerNorm(nn.Module):
 
         if attention_mask is not None:
             # make sure padded tokens are not attended to
-            hidden_states[~attention_mask] = 0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
 
         position_embeddings = self.pos_conv_embed(hidden_states)
         hidden_states = hidden_states + position_embeddings
@@ -1508,7 +1510,8 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
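With the expanded masks in place, these models should trace through ONNX export again. A minimal sketch of how the fix might be verified; the checkpoint name, opset version, and return_dict toggle are illustrative assumptions, not part of this commit:

# Sketch only: export a wav2vec2 sequence-classification model, which
# exercises the fixed masked-pooling path when an attention mask is passed.
import torch
from transformers import Wav2Vec2ForSequenceClassification

# Illustrative checkpoint; any wav2vec2-style audio classifier should do.
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks")
model.config.return_dict = False  # tuple outputs tend to trace more reliably
model.eval()

dummy_input = torch.randn(1, 16000)                   # 1 s of 16 kHz audio
dummy_mask = torch.ones(1, 16000, dtype=torch.long)   # all samples valid

torch.onnx.export(
    model,
    (dummy_input, dummy_mask),
    "wav2vec2_seq_cls.onnx",
    input_names=["input_values", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={
        "input_values": {0: "batch", 1: "samples"},
        "attention_mask": {0: "batch", 1: "samples"},
    },
    opset_version=14,
)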