Added with torch.no_grad() to Camembert integration test (#21544)

add with torch.no_grad() to Camembert integration test

Co-authored-by: Bibi <Bibi@katies-mac.local>
This commit is contained in:
Katie Le 2023-02-10 04:58:29 -05:00 committed by GitHub
parent f83942684d
commit 21a2d900ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -39,7 +39,8 @@ class CamembertModelIntegrationTest(unittest.TestCase):
device=torch_device,
dtype=torch.long,
) # J'aime le camembert !
output = model(input_ids)["last_hidden_state"]
with torch.no_grad():
output = model(input_ids)["last_hidden_state"]
expected_shape = torch.Size((1, 10, 768))
self.assertEqual(output.shape, expected_shape)
# compare the actual values for a slice.