Add AMD expectation to test_gpt2_sample (#38079)

This commit is contained in:
ivarflakstad 2025-05-12 16:51:21 +02:00 committed by GitHub
parent 4220039b29
commit 7eaa90b87b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -20,6 +20,7 @@ import pytest
from transformers import GPT2Config, is_torch_available
from transformers.testing_utils import (
Expectations,
    cleanup,
    require_flash_attn,
    require_torch,
@@ -817,10 +818,14 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
        output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True)
        output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True)
-        EXPECTED_OUTPUT_STR = (
-            "Today is a nice day and if you don't know anything about the state of play during your holiday"
-        )
-        self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
+        expected_outputs = Expectations(
+            {
+                ("rocm", None): 'Today is a nice day and we can do this again."\n\nDana said that she will',
+                ("cuda", None): "Today is a nice day and if you don't know anything about the state of play during your holiday",
+            }
+        )  # fmt: skip
+        EXPECTED_OUTPUT = expected_outputs.get_expectation()
+        self.assertEqual(output_str, EXPECTED_OUTPUT)
        self.assertTrue(
            all(output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs)))
        )  # token_type_ids should change output