mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Wav2Vec2ForPretraining] Correct checkpoints wav2vec2 & fix tests (#12089)
* fix_torch_device_generate_test
* remove @
* fix tests
This commit is contained in:
parent
61e191987d
commit
bc6f51e539
@@ -349,6 +349,8 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
module.bias.data.fill_(3)
|
||||
if hasattr(module, "codevectors") and module.codevectors is not None:
|
||||
module.codevectors.data.fill_(3)
|
||||
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
|
||||
module.masked_spec_embed.data.fill_(3)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
@@ -487,6 +489,8 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
module.bias.data.fill_(3)
|
||||
if hasattr(module, "codevectors") and module.codevectors is not None:
|
||||
module.codevectors.data.fill_(3)
|
||||
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
|
||||
module.masked_spec_embed.data.fill_(3)
|
||||
|
||||
def test_model_for_pretraining(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -677,10 +681,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
def test_inference_integration(self):
|
||||
model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
|
||||
model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
|
||||
model.to(torch_device)
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
"patrickvonplaten/wav2vec2-base", return_attention_mask=True
|
||||
"facebook/wav2vec2-base", return_attention_mask=True
|
||||
)
|
||||
input_speech = self._load_datasamples(2)
|
||||
|
||||
@@ -723,10 +727,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))
|
||||
|
||||
def test_inference_pretrained(self):
|
||||
model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
|
||||
model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
|
||||
model.to(torch_device)
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
"patrickvonplaten/wav2vec2-base", return_attention_mask=True
|
||||
"facebook/wav2vec2-base", return_attention_mask=True
|
||||
)
|
||||
input_speech = self._load_datasamples(2)
|
||||
|
||||
@@ -761,7 +765,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
# ... now compare to randomly initialized model
|
||||
|
||||
config = Wav2Vec2Config.from_pretrained("patrickvonplaten/wav2vec2-base")
|
||||
config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base")
|
||||
model_rand = Wav2Vec2ForPreTraining(config).to(torch_device).eval()
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -785,9 +789,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
# => the cosine similarity between quantized states and predicted states is very likely < 0.1
|
||||
self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)
|
||||
|
||||
@unittest.skipIf(torch_device != "cpu", "cannot make deterministic on GPU")
|
||||
def test_loss_pretraining(self):
|
||||
model = Wav2Vec2ForPreTraining.from_pretrained(
|
||||
"patrickvonplaten/wav2vec2-base",
|
||||
"facebook/wav2vec2-base",
|
||||
attention_dropout=0.0,
|
||||
feat_proj_dropout=0.0,
|
||||
hidden_dropout=0.0,
|
||||
@@ -796,7 +801,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
model.to(torch_device).train()
|
||||
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
"patrickvonplaten/wav2vec2-base", return_attention_mask=True
|
||||
"facebook/wav2vec2-base", return_attention_mask=True
|
||||
)
|
||||
input_speech = self._load_datasamples(2)
|
||||
|
||||
@@ -829,6 +834,6 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3)
|
||||
|
||||
# check overall loss (contrastive loss + diversity loss)
|
||||
expected_loss = 62.5170 if model.device.type == "cpu" else 50.3612
|
||||
expected_loss = 62.5170
|
||||
|
||||
self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)
|
||||
|
Loading…
Reference in New Issue
Block a user