Fix gelu test for torch 1.10 (#14167)

This commit is contained in:
Lysandre Debut 2021-10-26 22:20:51 -04:00 committed by GitHub
parent 8ddbfe9752
commit 1e53faeb2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -29,8 +29,8 @@ class TestActivations(unittest.TestCase):
def test_gelu_versions(self):
x = torch.tensor([-100, -1, -0.1, 0, 0.1, 1.0, 100])
torch_builtin = get_activation("gelu")
self.assertTrue(torch.eq(_gelu_python(x), torch_builtin(x)).all().item())
self.assertFalse(torch.eq(_gelu_python(x), gelu_new(x)).all().item())
self.assertTrue(torch.allclose(_gelu_python(x), torch_builtin(x)))
self.assertFalse(torch.allclose(_gelu_python(x), gelu_new(x)))
def test_get_activation(self):
get_activation("swish")