mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
[Wav2Vec2] Make sure tensors are always bool for mask_indices (#13977)
* correct long to bool * up * correct code
This commit is contained in:
parent
11c043d27d
commit
58bf882579
@ -907,7 +907,7 @@ class HubertModel(HubertPreTrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
min_masks=2,
|
||||
)
|
||||
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.long)
|
||||
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
|
||||
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
||||
|
||||
if self.config.mask_feature_prob > 0 and self.training:
|
||||
@ -917,7 +917,7 @@ class HubertModel(HubertPreTrainedModel):
|
||||
mask_prob=self.config.mask_feature_prob,
|
||||
mask_length=self.config.mask_feature_length,
|
||||
)
|
||||
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.long)[
|
||||
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)[
|
||||
:, None
|
||||
].expand(-1, sequence_length, -1)
|
||||
hidden_states[mask_feature_indices] = 0
|
||||
|
@ -1100,7 +1100,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
min_masks=2,
|
||||
)
|
||||
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.long)
|
||||
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
|
||||
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
||||
|
||||
if self.config.mask_feature_prob > 0 and self.training:
|
||||
@ -1110,7 +1110,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
mask_prob=self.config.mask_feature_prob,
|
||||
mask_length=self.config.mask_feature_length,
|
||||
)
|
||||
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.long)[
|
||||
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)[
|
||||
:, None
|
||||
].expand(-1, sequence_length, -1)
|
||||
hidden_states[mask_feature_indices] = 0
|
||||
|
@ -738,6 +738,33 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertEqual(logits.shape, (4, 1498, 32))
|
||||
|
||||
def test_mask_time_feature_prob_ctc_single_batch(self):
|
||||
model = Wav2Vec2ForCTC.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-wav2vec2",
|
||||
mask_time_prob=0.2,
|
||||
mask_feature_prob=0.2,
|
||||
mask_time_length=2,
|
||||
mask_feature_length=2,
|
||||
)
|
||||
model.to(torch_device).train()
|
||||
processor = Wav2Vec2Processor.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
|
||||
)
|
||||
|
||||
batch_duration_in_seconds = [6]
|
||||
input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
|
||||
|
||||
batch = processor(
|
||||
input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
|
||||
)
|
||||
|
||||
logits = model(
|
||||
input_values=batch["input_values"].to(torch_device),
|
||||
attention_mask=batch["attention_mask"].to(torch_device),
|
||||
).logits
|
||||
|
||||
self.assertEqual(logits.shape, (1, 1498, 32))
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
Loading…
Reference in New Issue
Block a user