mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Fix tests (#14703)
This commit is contained in:
parent
68e53e6fcd
commit
7375758bee
@ -860,7 +860,8 @@ class PerceiverModelIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[-10.8609, -10.7651, -10.9187], [-12.1689, -11.9389, -12.1479], [-12.1518, -11.9707, -12.2073]]
|
||||
[[-10.8609, -10.7651, -10.9187], [-12.1689, -11.9389, -12.1479], [-12.1518, -11.9707, -12.2073]],
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(logits[0, :3, :3], expected_slice, atol=1e-4))
|
||||
@ -970,7 +971,7 @@ class PerceiverModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(inputs=patches)
|
||||
outputs = model(inputs=patches.to(torch_device))
|
||||
logits = outputs.logits
|
||||
|
||||
# verify logits
|
||||
|
@ -99,17 +99,17 @@ class PerceiverTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
# decoding
|
||||
decoded = tokenizer.decode(encoded_ids)
|
||||
self.assertEqual(decoded, "<cls>Unicode €.<sep>")
|
||||
self.assertEqual(decoded, "[CLS]Unicode €.[SEP]")
|
||||
|
||||
encoded = tokenizer("e è é ê ë")
|
||||
encoded_ids = [4, 107, 38, 201, 174, 38, 201, 175, 38, 201, 176, 38, 201, 177, 5]
|
||||
self.assertEqual(encoded["input_ids"], encoded_ids)
|
||||
# decoding
|
||||
decoded = tokenizer.decode(encoded_ids)
|
||||
self.assertEqual(decoded, "<cls>e è é ê ë<sep>")
|
||||
self.assertEqual(decoded, "[CLS]e è é ê ë[SEP]")
|
||||
|
||||
# encode/decode, but with `encode` instead of `__call__`
|
||||
self.assertEqual(tokenizer.decode(tokenizer.encode("e è é ê ë")), "<cls>e è é ê ë<sep>")
|
||||
self.assertEqual(tokenizer.decode(tokenizer.encode("e è é ê ë")), "[CLS]e è é ê ë[SEP]")
|
||||
|
||||
def test_prepare_batch_integration(self):
|
||||
tokenizer = self.perceiver_tokenizer
|
||||
|
Loading…
Reference in New Issue
Block a user