NielsRogge 2021-12-09 14:32:35 +01:00 committed by GitHub
parent 68e53e6fcd
commit 7375758bee
2 changed files with 6 additions and 5 deletions


@@ -860,7 +860,8 @@ class PerceiverModelIntegrationTest(unittest.TestCase):
         self.assertEqual(logits.shape, expected_shape)
         expected_slice = torch.tensor(
-            [[-10.8609, -10.7651, -10.9187], [-12.1689, -11.9389, -12.1479], [-12.1518, -11.9707, -12.2073]]
+            [[-10.8609, -10.7651, -10.9187], [-12.1689, -11.9389, -12.1479], [-12.1518, -11.9707, -12.2073]],
+            device=torch_device,
         )
         self.assertTrue(torch.allclose(logits[0, :3, :3], expected_slice, atol=1e-4))
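The added device= argument keeps the expected tensor on the same device as the model output: torch.allclose raises a RuntimeError when its arguments live on different devices, so a CPU-side expected_slice breaks the test on GPU runners. A minimal sketch of that failure mode, assuming a CUDA machine (the values are illustrative, not taken from this diff):

import torch

logits = torch.randn(1, 3, 3, device="cuda")       # model output lives on the GPU
expected = torch.tensor([[-10.9, -10.8, -10.9]])    # CPU tensor, as in the old test
# torch.allclose(logits[0, :1, :3], expected)       # RuntimeError: tensors on different devices
torch.allclose(logits[0, :1, :3], expected.to(logits.device), atol=1e-4)  # OK once moved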
@@ -970,7 +971,7 @@ class PerceiverModelIntegrationTest(unittest.TestCase):
         # forward pass
         with torch.no_grad():
-            outputs = model(inputs=patches)
+            outputs = model(inputs=patches.to(torch_device))
         logits = outputs.logits
         # verify logits
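This follows the usual pattern in the Transformers test suite: torch_device (from transformers.testing_utils) resolves to "cuda" when a GPU is available, the model is placed there, and any input tensor built on the CPU has to be moved explicitly before the forward pass. A small sketch of the pattern with a stand-in module (the Linear model is only illustrative):

import torch
from transformers.testing_utils import torch_device  # "cuda" when a GPU is available, else "cpu"

model = torch.nn.Linear(16, 4).to(torch_device)   # stand-in for the Perceiver model under test
patches = torch.randn(2, 16)                      # inputs are created on the CPU

with torch.no_grad():
    outputs = model(patches.to(torch_device))     # inputs must follow the model's device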


@@ -99,17 +99,17 @@ class PerceiverTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         # decoding
         decoded = tokenizer.decode(encoded_ids)
-        self.assertEqual(decoded, "<cls>Unicode €.<sep>")
+        self.assertEqual(decoded, "[CLS]Unicode €.[SEP]")

         encoded = tokenizer("e è é ê ë")
         encoded_ids = [4, 107, 38, 201, 174, 38, 201, 175, 38, 201, 176, 38, 201, 177, 5]
         self.assertEqual(encoded["input_ids"], encoded_ids)

         # decoding
         decoded = tokenizer.decode(encoded_ids)
-        self.assertEqual(decoded, "<cls>e è é ê ë<sep>")
+        self.assertEqual(decoded, "[CLS]e è é ê ë[SEP]")

         # encode/decode, but with `encode` instead of `__call__`
-        self.assertEqual(tokenizer.decode(tokenizer.encode("e è é ê ë")), "<cls>e è é ê ë<sep>")
+        self.assertEqual(tokenizer.decode(tokenizer.encode("e è é ê ë")), "[CLS]e è é ê ë[SEP]")

     def test_prepare_batch_integration(self):
         tokenizer = self.perceiver_tokenizer
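The updated expectations reflect how the Perceiver tokenizer actually decodes: it works directly on UTF-8 bytes and wraps sequences in [CLS]/[SEP] special tokens rather than <cls>/<sep>. A short sketch of that behaviour; the checkpoint name is an assumption, the test itself builds its tokenizer from local fixtures instead:

from transformers import PerceiverTokenizer

# Checkpoint name is an assumption, used only to get a ready-made byte-level tokenizer.
tokenizer = PerceiverTokenizer.from_pretrained("deepmind/language-perceiver")

ids = tokenizer("Unicode €.")["input_ids"]   # UTF-8 byte ids wrapped in [CLS] ... [SEP]
print(tokenizer.decode(ids))                 # -> "[CLS]Unicode €.[SEP]"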