[Wav2Vec2ForPretraining] Correct checkpoints wav2vec2 & fix tests (#12089)

* fix_torch_device_generate_test

* remove @

* fix tests
Patrick von Platen 2021-06-09 20:41:59 +01:00 committed by GitHub
parent 61e191987d
commit bc6f51e539

@@ -349,6 +349,8 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
             module.bias.data.fill_(3)
         if hasattr(module, "codevectors") and module.codevectors is not None:
             module.codevectors.data.fill_(3)
+        if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
+            module.masked_spec_embed.data.fill_(3)
 
     @slow
     def test_model_from_pretrained(self):
@@ -487,6 +489,8 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
             module.bias.data.fill_(3)
         if hasattr(module, "codevectors") and module.codevectors is not None:
             module.codevectors.data.fill_(3)
+        if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
+            module.masked_spec_embed.data.fill_(3)
 
     def test_model_for_pretraining(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
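Note: the two added lines in each class extend the tests' mocked weight initialization so that the learned masked_spec_embed parameter is also filled with a constant; otherwise its random initialization makes pretraining test outputs non-reproducible. A minimal sketch of the pattern (TinyEncoder below is a hypothetical stand-in, not a transformers class):

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    # hypothetical stand-in for a Wav2Vec2-style module with a learned mask embedding
    def __init__(self, hidden_size=8):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.masked_spec_embed = nn.Parameter(torch.zeros(hidden_size))

def _mock_init_weights(module):
    # fill every recognized parameter with the same constant so forward
    # passes are reproducible across runs, mirroring the test's mock
    if hasattr(module, "weight") and module.weight is not None:
        module.weight.data.fill_(3)
    if hasattr(module, "bias") and module.bias is not None:
        module.bias.data.fill_(3)
    if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
        module.masked_spec_embed.data.fill_(3)

model = TinyEncoder()
model.apply(_mock_init_weights)  # .apply() visits the module and all submodules
assert torch.all(model.masked_spec_embed == 3)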
@@ -677,10 +681,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
 
     def test_inference_integration(self):
-        model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
+        model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
         model.to(torch_device)
         feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
-            "patrickvonplaten/wav2vec2-base", return_attention_mask=True
+            "facebook/wav2vec2-base", return_attention_mask=True
         )
         input_speech = self._load_datasamples(2)
 
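The corrected checkpoint can be exercised outside the test suite along these lines. A sketch, not the test itself: the random waveform stands in for the LibriSpeech samples that _load_datasamples(2) fetches in the real test.

import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining

model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base").eval()
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
    "facebook/wav2vec2-base", return_attention_mask=True
)

raw_speech = [torch.randn(16000).numpy()]  # one second of placeholder 16 kHz audio
inputs = feature_extractor(
    raw_speech, sampling_rate=16000, padding=True, return_tensors="pt"
)

with torch.no_grad():
    outputs = model(inputs.input_values, attention_mask=inputs.attention_mask)

# cosine similarity between the transformer's predicted (projected) states and
# the quantized targets -- the quantity the integration tests assert on
cosine_sim = torch.cosine_similarity(
    outputs.projected_states, outputs.projected_quantized_states, dim=-1
)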
@@ -723,10 +727,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))
 
     def test_inference_pretrained(self):
-        model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
+        model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
         model.to(torch_device)
         feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
-            "patrickvonplaten/wav2vec2-base", return_attention_mask=True
+            "facebook/wav2vec2-base", return_attention_mask=True
         )
         input_speech = self._load_datasamples(2)
 
@@ -761,7 +765,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
 
         # ... now compare to randomly initialized model
 
-        config = Wav2Vec2Config.from_pretrained("patrickvonplaten/wav2vec2-base")
+        config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base")
         model_rand = Wav2Vec2ForPreTraining(config).to(torch_device).eval()
 
         with torch.no_grad():
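Continuing the sketch above: the randomly initialized baseline built in this hunk is what makes the check meaningful, since a pretrained checkpoint has to clearly beat a model that has learned nothing. Roughly:

from transformers import Wav2Vec2Config

config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base")
model_rand = Wav2Vec2ForPreTraining(config).eval()  # random weights, same config

with torch.no_grad():
    outputs_rand = model_rand(inputs.input_values, attention_mask=inputs.attention_mask)

cosine_sim_rand = torch.cosine_similarity(
    outputs_rand.projected_states, outputs_rand.projected_quantized_states, dim=-1
)

# the real test restricts both similarities to masked time steps and demands
# a 5x margin; this unmasked mean comparison is a simplified stand-in
assert cosine_sim.mean() > cosine_sim_rand.mean()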
@@ -785,9 +789,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         # => the cosine similarity between quantized states and predicted states is very likely < 0.1
         self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)
 
+    @unittest.skipIf(torch_device != "cpu", "cannot make deterministic on GPU")
     def test_loss_pretraining(self):
         model = Wav2Vec2ForPreTraining.from_pretrained(
-            "patrickvonplaten/wav2vec2-base",
+            "facebook/wav2vec2-base",
             attention_dropout=0.0,
             feat_proj_dropout=0.0,
             hidden_dropout=0.0,
@@ -796,7 +801,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         model.to(torch_device).train()
 
         feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
-            "patrickvonplaten/wav2vec2-base", return_attention_mask=True
+            "facebook/wav2vec2-base", return_attention_mask=True
         )
         input_speech = self._load_datasamples(2)
@@ -829,6 +834,6 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3)
 
         # check overall loss (contrastive loss + diversity loss)
-        expected_loss = 62.5170 if model.device.type == "cpu" else 50.3612
+        expected_loss = 62.5170
         self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)
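The last two hunks belong together: once @unittest.skipIf confines test_loss_pretraining to CPU, the GPU branch of the expected-loss expression (50.3612) becomes unreachable and the conditional collapses to the single CPU value. The guard pattern, sketched standalone (torch_device here stands for the device string the transformers test utilities resolve; in the test file it is imported, not computed inline):

import unittest
import torch

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

class LossTest(unittest.TestCase):
    @unittest.skipIf(torch_device != "cpu", "cannot make deterministic on GPU")
    def test_loss_pretraining(self):
        # runs only on CPU, so one deterministic expectation suffices
        expected_loss = 62.5170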