fix test_compare_unprocessed_logit_scores (#39053)

fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2025-06-26 18:36:56 +02:00 committed by GitHub
parent 58c7689226
commit 23b7e73f05
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3807,7 +3807,7 @@ class GenerationIntegrationTests(unittest.TestCase):
logits_gen = outputs.logits[0][0]
# assert that unprocessed logits from generate() are same as those from modal eval()
self.assertListEqual(logits_fwd.tolist(), logits_gen.tolist())
torch.testing.assert_allclose(logits_fwd.tolist(), logits_gen.tolist())
def test_return_unprocessed_logit_scores(self):
# tell model to generate text and return unprocessed/unwarped logit scores