mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Fix tests (#14703)
This commit is contained in:
parent
68e53e6fcd
commit
7375758bee
@ -860,7 +860,8 @@ class PerceiverModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(logits.shape, expected_shape)
|
self.assertEqual(logits.shape, expected_shape)
|
||||||
|
|
||||||
expected_slice = torch.tensor(
|
expected_slice = torch.tensor(
|
||||||
[[-10.8609, -10.7651, -10.9187], [-12.1689, -11.9389, -12.1479], [-12.1518, -11.9707, -12.2073]]
|
[[-10.8609, -10.7651, -10.9187], [-12.1689, -11.9389, -12.1479], [-12.1518, -11.9707, -12.2073]],
|
||||||
|
device=torch_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(logits[0, :3, :3], expected_slice, atol=1e-4))
|
self.assertTrue(torch.allclose(logits[0, :3, :3], expected_slice, atol=1e-4))
|
||||||
@ -970,7 +971,7 @@ class PerceiverModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(inputs=patches)
|
outputs = model(inputs=patches.to(torch_device))
|
||||||
logits = outputs.logits
|
logits = outputs.logits
|
||||||
|
|
||||||
# verify logits
|
# verify logits
|
||||||
|
@ -99,17 +99,17 @@ class PerceiverTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
# decoding
|
# decoding
|
||||||
decoded = tokenizer.decode(encoded_ids)
|
decoded = tokenizer.decode(encoded_ids)
|
||||||
self.assertEqual(decoded, "<cls>Unicode €.<sep>")
|
self.assertEqual(decoded, "[CLS]Unicode €.[SEP]")
|
||||||
|
|
||||||
encoded = tokenizer("e è é ê ë")
|
encoded = tokenizer("e è é ê ë")
|
||||||
encoded_ids = [4, 107, 38, 201, 174, 38, 201, 175, 38, 201, 176, 38, 201, 177, 5]
|
encoded_ids = [4, 107, 38, 201, 174, 38, 201, 175, 38, 201, 176, 38, 201, 177, 5]
|
||||||
self.assertEqual(encoded["input_ids"], encoded_ids)
|
self.assertEqual(encoded["input_ids"], encoded_ids)
|
||||||
# decoding
|
# decoding
|
||||||
decoded = tokenizer.decode(encoded_ids)
|
decoded = tokenizer.decode(encoded_ids)
|
||||||
self.assertEqual(decoded, "<cls>e è é ê ë<sep>")
|
self.assertEqual(decoded, "[CLS]e è é ê ë[SEP]")
|
||||||
|
|
||||||
# encode/decode, but with `encode` instead of `__call__`
|
# encode/decode, but with `encode` instead of `__call__`
|
||||||
self.assertEqual(tokenizer.decode(tokenizer.encode("e è é ê ë")), "<cls>e è é ê ë<sep>")
|
self.assertEqual(tokenizer.decode(tokenizer.encode("e è é ê ë")), "[CLS]e è é ê ë[SEP]")
|
||||||
|
|
||||||
def test_prepare_batch_integration(self):
|
def test_prepare_batch_integration(self):
|
||||||
tokenizer = self.perceiver_tokenizer
|
tokenizer = self.perceiver_tokenizer
|
||||||
|
Loading…
Reference in New Issue
Block a user