Use assertAlmostEqual in BloomEmbeddingTest.test_logits (#19200)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-09-26 14:56:41 +02:00 committed by GitHub
parent 98af4f9b54
commit ea75e9f10e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -771,8 +771,8 @@ class BloomEmbeddingTest(unittest.TestCase):
output_gpu_1, output_gpu_2 = output.split(125440, dim=-1)
if cuda_available:
self.assertEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1)
self.assertEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2)
self.assertAlmostEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1, places=6)
self.assertAlmostEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2, places=6)
else:
self.assertAlmostEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1, places=6) # 1e-06 precision!!
self.assertAlmostEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2, places=6)