From fbd8c51ffcc23611f99f5a75fe232f1b010eeb72 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Fri, 19 Apr 2024 17:18:36 +0200 Subject: [PATCH] Restore casting of masked_spec_embed (#30336) * fix Parameter dtype in audio models * restore casting of masked_spec_embed * restore casting of masked_spec_embed --- src/transformers/models/data2vec/modeling_data2vec_audio.py | 4 ++-- src/transformers/models/hubert/modeling_hubert.py | 4 ++-- src/transformers/models/sew/modeling_sew.py | 4 ++-- src/transformers/models/sew_d/modeling_sew_d.py | 4 ++-- src/transformers/models/speecht5/modeling_speecht5.py | 4 ++-- src/transformers/models/unispeech/modeling_unispeech.py | 4 ++-- .../models/unispeech_sat/modeling_unispeech_sat.py | 4 ++-- src/transformers/models/wav2vec2/modeling_wav2vec2.py | 4 ++-- .../models/wav2vec2_bert/modeling_wav2vec2_bert.py | 4 ++-- .../models/wav2vec2_conformer/modeling_wav2vec2_conformer.py | 4 ++-- src/transformers/models/wavlm/modeling_wavlm.py | 4 ++-- 11 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 6df96aa49bb..3504258a58e 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -858,7 +858,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel): if mask_time_indices is not None: # apply SpecAugment along time axis with given mask_time_indices - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) elif self.config.mask_time_prob > 0 and self.training: mask_time_indices = _compute_mask_indices( (batch_size, sequence_length), @@ -868,7 +868,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel): min_masks=self.config.mask_time_min_masks, ) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) if self.config.mask_feature_prob > 0 and self.training: # generate indices & apply SpecAugment along feature axis diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index d17119426a5..257720cfe2a 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -1005,7 +1005,7 @@ class HubertModel(HubertPreTrainedModel): if mask_time_indices is not None: # apply SpecAugment along time axis with given mask_time_indices - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) elif self.config.mask_time_prob > 0 and self.training: mask_time_indices = _compute_mask_indices( (batch_size, sequence_length), @@ -1015,7 +1015,7 @@ class HubertModel(HubertPreTrainedModel): min_masks=self.config.mask_time_min_masks, ) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) if self.config.mask_feature_prob > 0 and self.training: # generate indices & apply SpecAugment along feature axis diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 67e8dbdcbd0..cb6f82e2c7b 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -862,7 +862,7 @@ class SEWModel(SEWPreTrainedModel): if mask_time_indices is not None: # apply SpecAugment along time axis with given mask_time_indices - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) elif self.config.mask_time_prob > 0 and self.training: mask_time_indices = _compute_mask_indices( (batch_size, sequence_length), @@ -872,7 +872,7 @@ class SEWModel(SEWPreTrainedModel): min_masks=self.config.mask_time_min_masks, ) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) if self.config.mask_feature_prob > 0 and self.training: # generate indices & apply SpecAugment along feature axis diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 07da31afab5..84bf303cd52 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -1388,7 +1388,7 @@ class SEWDModel(SEWDPreTrainedModel): if mask_time_indices is not None: # apply SpecAugment along time axis with given mask_time_indices - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) elif self.config.mask_time_prob > 0 and self.training: mask_time_indices = _compute_mask_indices( (batch_size, sequence_length), @@ -1398,7 +1398,7 @@ class SEWDModel(SEWDPreTrainedModel): min_masks=self.config.mask_time_min_masks, ) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) if self.config.mask_feature_prob > 0 and self.training: # generate indices & apply SpecAugment along feature axis diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index e2b38d01929..5caac417027 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -616,7 +616,7 @@ class SpeechT5SpeechEncoderPrenet(nn.Module): if mask_time_indices is not None: # apply SpecAugment along time axis with given mask_time_indices - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) elif self.config.mask_time_prob > 0 and self.training: mask_time_indices = _compute_mask_indices( (batch_size, sequence_length), @@ -626,7 +626,7 @@ class SpeechT5SpeechEncoderPrenet(nn.Module): min_masks=self.config.mask_time_min_masks, ) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) if self.config.mask_feature_prob > 0 and self.training: # generate indices & apply SpecAugment along feature axis diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 47dae1f3a8f..16e4af7d485 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -1121,7 +1121,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel): if mask_time_indices is not None: # apply SpecAugment along time axis with given mask_time_indices - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) elif self.config.mask_time_prob > 0 and self.training: mask_time_indices = _compute_mask_indices( (batch_size, sequence_length), @@ -1131,7 +1131,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel): min_masks=self.config.mask_time_min_masks, ) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) if self.config.mask_feature_prob > 0 and self.training: # generate indices & apply SpecAugment along feature axis diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 2a882874fdb..7bdf33848c3 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -1139,7 +1139,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel): if mask_time_indices is not None: # apply SpecAugment along time axis with given mask_time_indices - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) elif self.config.mask_time_prob > 0 and self.training: mask_time_indices = _compute_mask_indices( (batch_size, sequence_length), @@ -1149,7 +1149,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel): min_masks=self.config.mask_time_min_masks, ) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) if self.config.mask_feature_prob > 0 and self.training: # generate indices & apply SpecAugment along feature axis diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 0773af3a561..99b7f2c23e5 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1496,7 +1496,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): if mask_time_indices is not None: # apply SpecAugment along time axis with given mask_time_indices - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) elif self.config.mask_time_prob > 0 and self.training: mask_time_indices = _compute_mask_indices( (batch_size, sequence_length), @@ -1506,7 +1506,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): min_masks=self.config.mask_time_min_masks, ) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) if self.config.mask_feature_prob > 0 and self.training: # generate indices & apply SpecAugment along feature axis diff --git a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py index 21c76048aaf..1bff2956f41 100644 --- a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +++ b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py @@ -1087,7 +1087,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel): if mask_time_indices is not None: # apply SpecAugment along time axis with given mask_time_indices - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) elif self.config.mask_time_prob > 0 and self.training: mask_time_indices = _compute_mask_indices( (batch_size, sequence_length), @@ -1097,7 +1097,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel): min_masks=self.config.mask_time_min_masks, ) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) if self.config.mask_feature_prob > 0 and self.training: # generate indices & apply SpecAugment along feature axis diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index 2c4f5c289af..9109c15bb1b 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -1273,7 +1273,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel): if mask_time_indices is not None: # apply SpecAugment along time axis with given mask_time_indices - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) elif self.config.mask_time_prob > 0 and self.training: mask_time_indices = _compute_mask_indices( (batch_size, sequence_length), @@ -1283,7 +1283,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel): min_masks=self.config.mask_time_min_masks, ) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) if self.config.mask_feature_prob > 0 and self.training: # generate indices & apply SpecAugment along feature axis diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 94833e86a10..5d1a44c00a2 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -1158,7 +1158,7 @@ class WavLMModel(WavLMPreTrainedModel): if mask_time_indices is not None: # apply SpecAugment along time axis with given mask_time_indices - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) elif self.config.mask_time_prob > 0 and self.training: mask_time_indices = _compute_mask_indices( (batch_size, sequence_length), @@ -1168,7 +1168,7 @@ class WavLMModel(WavLMPreTrainedModel): min_masks=self.config.mask_time_min_masks, ) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) - hidden_states[mask_time_indices] = self.masked_spec_embed + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) if self.config.mask_feature_prob > 0 and self.training: # generate indices & apply SpecAugment along feature axis